Add scripts for Flowtron
- Python script to generate Flowtron filelists - Flowtron patch to enable inference from models
This commit is contained in:
parent
c3f857a9ee
commit
0b226fce6e
3 changed files with 90 additions and 1 deletions
|
@ -21,7 +21,7 @@ def get_unique_json_values(extract_dir, path_pattern, extract_values):
|
|||
# Recursively process files in extraction directory, whose path matches a pattern
|
||||
for f in glob.glob("{}/**".format(extract_dir), recursive=True):
|
||||
if re.search(path_pattern, f):
|
||||
with open(f, 'r') as fp:
|
||||
with open(f, "r") as fp:
|
||||
obj = json.load(fp)
|
||||
for value in extract_values(obj):
|
||||
count_value(value, values)
|
||||
|
|
61
scripts/filelists.py
Normal file
61
scripts/filelists.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
"""Script to generate a filelist for training a Flowtron model of a particular NPC."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
# Input locations — edit these to point at your local KotOR extraction
# and the directory of speaker WAV files prepared for training.
extract_dir = r"D:\OpenKotOR\Extract\KotOR"
wav_dir = r"D:\OpenKotOR\TTS\bastila\train"

# Fail fast with a clear message if either directory is missing.
for _path, _label in ((extract_dir, "Extraction"), (wav_dir, "WAV")):
    if not os.path.exists(_path):
        raise RuntimeError("{} directory does not exist".format(_label))
|
||||
|
||||
|
||||
def get_lines_from_tlk(obj, speaker):
    """Extract filelist lines for *speaker* from one parsed TLK JSON object.

    Returns a list of "wav_path|text|0\n" lines, one per unique sound
    resref whose WAV file exists under the module-level ``wav_dir``.
    Non-voiced entries (resref not starting with "n") and pure stage
    directions ("[...]" text) are skipped.
    """
    result = []
    if "strings" not in obj:
        return result

    seen_resrefs = set()
    for entry in obj["strings"]:
        # Only entries whose sound resref mentions the speaker are relevant.
        if "soundResRef" not in entry or speaker not in entry["soundResRef"]:
            continue

        resref = entry["soundResRef"].lower()
        text = entry["text"]

        # Keep voiced lines ("n" prefix), drop bracketed stage directions
        # and resrefs we have already emitted.
        if not resref.startswith("n"):
            continue
        if text.startswith("[") and text.endswith("]"):
            continue
        if resref in seen_resrefs:
            continue

        wav_path = os.path.join(wav_dir, resref + ".wav")
        if os.path.exists(wav_path):
            result.append("{}|{}|0\n".format(wav_path, text))
            seen_resrefs.add(resref)

    return result
|
||||
|
||||
|
||||
def generate_filelist(extract_dir, speaker):
    """Write <speaker>_train_filelist.txt / <speaker>_val_filelist.txt.

    Walks *extract_dir* recursively for ``*.tlk.json`` dumps, collects the
    speaker's lines, shuffles them, and holds out 5% for validation.
    """
    # Gather candidate lines from every extracted TLK file.
    all_lines = []
    search_pattern = "{}/**".format(extract_dir)
    for path in glob.glob(search_pattern, recursive=True):
        if not path.endswith(".tlk.json"):
            continue
        with open(path, "r") as handle:
            tlk_obj = json.load(handle)
            all_lines.extend(get_lines_from_tlk(tlk_obj, speaker))

    # Shuffle, then split off 5% of the lines for validation.
    random.shuffle(all_lines)
    num_val = int(5 * len(all_lines) / 100)

    with open(speaker + "_train_filelist.txt", "w") as handle:
        handle.writelines(all_lines[num_val:])
    with open(speaker + "_val_filelist.txt", "w") as handle:
        handle.writelines(all_lines[:num_val])
|
||||
|
||||
|
||||
# Script entry point. Guarded so importing this module (e.g. to reuse
# get_lines_from_tlk) does not kick off filelist generation.
if __name__ == "__main__":
    if len(sys.argv) > 1:
        # First CLI argument is the speaker name, e.g. "bastila".
        generate_filelist(extract_dir, sys.argv[1])
    else:
        # Fixed: usage previously named a non-existent "ttsfilelist.py";
        # this script is scripts/filelists.py.
        print('Usage: python filelists.py speaker')
|
28
scripts/infer_from_model.patch
Normal file
28
scripts/infer_from_model.patch
Normal file
|
@ -0,0 +1,28 @@
|
|||
diff --git a/inference.py b/inference.py
|
||||
index 791b09a..ebfac4e 100644
|
||||
--- a/inference.py
|
||||
+++ b/inference.py
|
||||
@@ -51,7 +51,11 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
|
||||
|
||||
# load flowtron
|
||||
model = Flowtron(**model_config).cuda()
|
||||
- state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
|
||||
+ state_dict = torch.load(flowtron_path, map_location='cpu')
|
||||
+ if 'model' in state_dict:
|
||||
+ state_dict = state_dict['model'].state_dict()
|
||||
+ else:
|
||||
+ state_dict = state_dict['state_dict']
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
print("Loaded checkpoint '{}')" .format(flowtron_path))
|
||||
@@ -73,8 +77,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
|
||||
for k in range(len(attentions)):
|
||||
attention = torch.cat(attentions[k]).cpu().numpy()
|
||||
fig, axes = plt.subplots(1, 2, figsize=(16, 4))
|
||||
- axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
|
||||
- axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
|
||||
+ axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
|
||||
+ axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
|
||||
fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
|
||||
plt.close("all")
|
||||
|
Loading…
Reference in a new issue