(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False, plot_log=False)
| 9 | |
| 10 | |
def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False, plot_log=False):
    """Draw an alignment matrix as a heat map and return the figure.

    Args:
        alignment: ``torch.Tensor`` or array-like holding the alignment
            values. Tensors are detached, moved to CPU, converted to NumPy
            and squeezed; any other input is passed through
            ``np.asanyarray`` without squeezing. The data is transposed
            before plotting, so its first axis runs along the x axis
            ("Decoder timestep") and its second axis along the y axis
            ("Encoder timestep").
        info: Optional text appended below the x-axis label, separated by
            a blank line.
        fig_size: Figure size forwarded to ``plt.subplots(figsize=...)``.
        title: Optional axes title.
        output_fig: When falsy, the figure is closed via pyplot before it
            is returned; when truthy it is left open.
        plot_log: When truthy, colors are mapped with ``LogNorm`` instead
            of the default linear normalization.

    Returns:
        The ``Figure`` created for the plot.
    """
    if isinstance(alignment, torch.Tensor):
        alignment_ = alignment.detach().cpu().numpy().squeeze()
    else:
        # asanyarray leaves ndarrays (and subclasses such as masked arrays)
        # untouched while also accepting plain nested sequences.
        alignment_ = np.asanyarray(alignment)

    # Half-precision data is widened before plotting.
    # NOTE(review): presumably because imshow does not handle float16 - confirm.
    if alignment_.dtype == np.float16:
        alignment_ = alignment_.astype(np.float32)

    fig, ax = plt.subplots(figsize=fig_size)
    im = ax.imshow(
        alignment_.T,
        aspect="auto",
        origin="lower",
        interpolation="none",
        norm=LogNorm() if plot_log else None,
    )
    fig.colorbar(im, ax=ax)

    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info

    # Address the axes explicitly instead of relying on pyplot's notion of
    # the "current" axes, which is ambiguous once the colorbar axes exists.
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Encoder timestep")
    if title is not None:
        ax.set_title(title)

    # Run the layout pass only after every label and the title are in place
    # so that all of them are accounted for when margins are computed.
    fig.tight_layout()

    if not output_fig:
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
    return fig
| 34 | |
| 35 | |
| 36 | def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): |
no outgoing calls
searching dependent graphs…