28 lines
1.4 KiB
Diff
28 lines
1.4 KiB
Diff
diff --git a/inference.py b/inference.py
|
|
index 791b09a..ebfac4e 100644
|
|
--- a/inference.py
|
|
+++ b/inference.py
|
|
@@ -51,7 +51,11 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
|
|
|
|
# load flowtron
|
|
model = Flowtron(**model_config).cuda()
|
|
- state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
|
|
+ state_dict = torch.load(flowtron_path, map_location='cpu')
|
|
+ if 'model' in state_dict:
|
|
+ state_dict = state_dict['model'].state_dict()
|
|
+ else:
|
|
+ state_dict = state_dict['state_dict']
|
|
model.load_state_dict(state_dict)
|
|
model.eval()
|
|
print("Loaded checkpoint '{}')" .format(flowtron_path))
|
|
@@ -73,8 +77,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
|
|
for k in range(len(attentions)):
|
|
attention = torch.cat(attentions[k]).cpu().numpy()
|
|
fig, axes = plt.subplots(1, 2, figsize=(16, 4))
|
|
- axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
|
|
- axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
|
|
+ axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
|
|
+ axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
|
|
fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
|
|
plt.close("all")
|
|
|