reone/scripts/flowtron/infer_from_model.patch

29 lines
1.4 KiB
Diff
Raw Normal View History

diff --git a/inference.py b/inference.py
index 791b09a..ebfac4e 100644
--- a/inference.py
+++ b/inference.py
@@ -51,7 +51,11 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
 
     # load flowtron
     model = Flowtron(**model_config).cuda()
-    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
+    state_dict = torch.load(flowtron_path, map_location='cpu')
+    if 'model' in state_dict:
+        state_dict = state_dict['model'].state_dict()
+    else:
+        state_dict = state_dict['state_dict']
     model.load_state_dict(state_dict)
     model.eval()
     print("Loaded checkpoint '{}')" .format(flowtron_path))
@@ -73,8 +77,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
     for k in range(len(attentions)):
         attention = torch.cat(attentions[k]).cpu().numpy()
         fig, axes = plt.subplots(1, 2, figsize=(16, 4))
-        axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
-        axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
+        axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
+        axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
         fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
         plt.close("all")