diff --git a/inference.py b/inference.py
index 791b09a..ebfac4e 100644
--- a/inference.py
+++ b/inference.py
@@ -51,7 +51,11 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
 
     # load flowtron
     model = Flowtron(**model_config).cuda()
-    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
+    state_dict = torch.load(flowtron_path, map_location='cpu')
+    if 'model' in state_dict:
+        state_dict = state_dict['model'].state_dict()
+    else:
+        state_dict = state_dict['state_dict']
     model.load_state_dict(state_dict)
     model.eval()
     print("Loaded checkpoint '{}')" .format(flowtron_path))
@@ -73,8 +77,8 @@ def infer(flowtron_path, waveglow_path, output_dir, text, speaker_id, n_frames,
     for k in range(len(attentions)):
         attention = torch.cat(attentions[k]).cpu().numpy()
         fig, axes = plt.subplots(1, 2, figsize=(16, 4))
-        axes[0].imshow(mels[0].cpu().numpy(), origin='bottom', aspect='auto')
-        axes[1].imshow(attention[:, 0].transpose(), origin='bottom', aspect='auto')
+        axes[0].imshow(mels[0].cpu().numpy(), origin='lower', aspect='auto')
+        axes[1].imshow(attention[:, 0].transpose(), origin='lower', aspect='auto')
         fig.savefig(os.path.join(output_dir, 'sid{}_sigma{}_attnlayer{}.png'.format(speaker_id, sigma, k)))
         plt.close("all")
 