Add scripts for Flowtron

- Python script to generate Flowtron filelists
- Flowtron patch to enable inference from models
This commit is contained in:
Vsevolod Kremianskii 2021-04-14 10:24:17 +07:00
parent c3f857a9ee
commit 0b226fce6e
3 changed files with 90 additions and 1 deletion

View file

@ -21,7 +21,7 @@ def get_unique_json_values(extract_dir, path_pattern, extract_values):
# Recursively process files in extraction directory, whose path matches a pattern
for f in glob.glob("{}/**".format(extract_dir), recursive=True):
if re.search(path_pattern, f):
with open(f, 'r') as fp:
with open(f, "r") as fp:
obj = json.load(fp)
for value in extract_values(obj):
count_value(value, values)

scripts/filelists.py — new file, 61 lines
View file

@ -0,0 +1,61 @@
"""Script to generate a filelist for training a Flowtron model of a particular NPC."""
import glob
import json
import os
import random
import sys
extract_dir = r"D:\OpenKotOR\Extract\KotOR"
wav_dir = r"D:\OpenKotOR\TTS\bastila\train"
if not os.path.exists(extract_dir):
raise RuntimeError("Extraction directory does not exist")
if not os.path.exists(wav_dir):
raise RuntimeError("WAV directory does not exist")
def get_lines_from_tlk(obj, speaker):
    """Extract Flowtron filelist lines for *speaker* from a parsed TLK JSON object.

    Returns a list of "wav_path|text|0\n" entries, one per unique voiced
    line whose .wav file exists under the module-level wav_dir.
    """
    lines = []
    if "strings" not in obj:
        return lines
    seen_resrefs = set()
    for string in obj["strings"]:
        # NOTE(review): this match is case-sensitive even though the resref
        # is lowercased below -- confirm resref casing is consistent in TLK.
        if "soundResRef" not in string or speaker not in string["soundResRef"]:
            continue
        resref = string["soundResRef"].lower()
        text = string["text"]
        # Keep only NPC lines ("n" prefix); skip stage directions like
        # "[Sigh]" and resrefs we have already emitted.
        is_direction = text.startswith("[") and text.endswith("]")
        if not resref.startswith("n") or is_direction or resref in seen_resrefs:
            continue
        # Only lines whose audio was actually extracted are usable.
        wav_filename = os.path.join(wav_dir, resref + ".wav")
        if os.path.exists(wav_filename):
            lines.append("{}|{}|0\n".format(wav_filename, text))
            seen_resrefs.add(resref)
    return lines
def generate_filelist(extract_dir, speaker):
    """Generate Flowtron train/val filelists for *speaker*.

    Recursively scans extract_dir for *.tlk.json dumps, collects voiced
    lines via get_lines_from_tlk, shuffles them, and writes 95% to
    "<speaker>_train_filelist.txt" and 5% to "<speaker>_val_filelist.txt"
    in the current directory.
    """
    # Extract lines from all TLK files.
    lines = []
    for f in glob.glob("{}/**".format(extract_dir), recursive=True):
        if f.endswith(".tlk.json"):
            # JSON is UTF-8; relying on the platform default encoding
            # (cp1252 on Windows) would break on non-ASCII dialog text.
            with open(f, "r", encoding="utf-8") as fp:
                obj = json.load(fp)
            lines.extend(get_lines_from_tlk(obj, speaker))
    # Split lines into training and validation filelists (5% validation).
    random.shuffle(lines)
    num_val = len(lines) * 5 // 100  # exact integer split, no float rounding
    lines_train = lines[num_val:]
    lines_val = lines[:num_val]
    with open(speaker + "_train_filelist.txt", "w", encoding="utf-8") as fp:
        fp.writelines(lines_train)
    with open(speaker + "_val_filelist.txt", "w", encoding="utf-8") as fp:
        fp.writelines(lines_val)
def _main():
    """CLI entry point: python filelists.py <speaker>."""
    if len(sys.argv) > 1:
        generate_filelist(extract_dir, sys.argv[1])
    else:
        # Usage goes to stderr; derive the script name from argv instead of
        # hard-coding it (the original said "ttsfilelist.py", which does not
        # match this file, scripts/filelists.py).
        print("Usage: python {} speaker".format(os.path.basename(sys.argv[0])),
              file=sys.stderr)


if __name__ == "__main__":
    # Guard so importing this module (e.g. from tests) has no side effects.
    _main()

View file

@ -0,0 +1,28 @@
diff --git a/inference.py b/inference.py
index 791b09a..ebfac4e 100644
--- a/inference.py
+++ b/inference.py
@@ -51,7 +51,11 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
# load flowtron
model = Flowtron(**model_config).cuda()
- state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
+ state_dict = torch.load(flowtron_path, map_location='cpu')
+ if 'model' in state_dict:
+ state_dict = state_dict['model'].state_dict()
+ else:
+ state_dict = state_dict['state_dict']
model.load_state_dict(state_dict)
model.eval()
print("Loaded checkpoint '{}')" .format(flowtron_path))
@@ -73,8 +77,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
for k in range(len(attentions)):
attention = torch.cat(attentions[k]).cpu().numpy()
fig, axes = plt.subplots(1, 2, figsize=(16, 4))
- axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
- axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
+ axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
+ axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
plt.close("all")