| 17 | |
| 18 | |
def plot(heads, intermediates, name):
    """Draw side-by-side heatmaps of per-layer MHA heads and FFN channels.

    Args:
        heads: 2-D array with one row per layer and one column per attention
            head. Drawn on a fixed 0.0-1.0 colour scale, so values outside
            that range saturate. (Presumably a keep/importance mask -- confirm
            against the caller.)
        intermediates: 2-D array with one row per layer and one column per
            FFN channel, drawn on the same 0.0-1.0 scale. The x axis is
            labelled in quarters of its width (1.0x ... 4.0x); this presumably
            assumes the FFN width is 4x the hidden size -- TODO confirm.
        name: Figure-level title (suptitle).

    Returns:
        The matplotlib ``Figure`` containing both panels.

    Note:
        Both panels use ``cmap="custom"``; a colormap registered under that
        name must exist before this function is called.
    """
    # Each FFN layer is one row of pixels; repeat it so the strip is tall
    # enough to see next to the (much wider) channel axis.
    row_repeat = 100

    fig, (ax_heads, ax_ffn) = plt.subplots(
        1, 2, facecolor='white', figsize=(10, 4), dpi=120,
        gridspec_kw={'width_ratios': [1.15, 3]})

    # --- Left panel: attention heads, one cell per (layer, head). ---
    heads_layers, heads_num = heads.shape[0], heads.shape[1]
    ax_heads.matshow(heads, cmap="custom", vmin=0.0, vmax=1.0)
    ax_heads.set_xlabel("Heads")
    ax_heads.set_ylabel("Layer")
    # Major ticks at cell centres, labelled 1-based.
    ax_heads.set_xticks(list(range(heads_num)),
                        [str(i + 1) for i in range(heads_num)])
    ax_heads.set_yticks(list(range(heads_layers)),
                        [str(i + 1) for i in range(heads_layers)])
    # Minor ticks on cell boundaries; they only anchor the white grid lines.
    ax_heads.set_xticks([i - 0.5 for i in range(heads_num)], minor=True)
    ax_heads.set_yticks([i - 0.5 for i in range(heads_layers)], minor=True)
    ax_heads.xaxis.tick_bottom()
    ax_heads.tick_params('both', length=0, width=0, which='both')
    ax_heads.grid(which='minor', color='w', linestyle='-', linewidth=1)
    ax_heads.set_title('MHAs')

    # --- Right panel: FFN channels, one stretched band per layer. ---
    ffn_layers = intermediates.shape[0]
    channel = intermediates.shape[1] / 4
    ax_ffn.matshow(intermediates.repeat(row_repeat, axis=0),
                   cmap="custom", vmin=0.0, vmax=1.0)
    ax_ffn.set_xlabel("FFNs channels")
    # Pixel k spans [k - 0.5, k + 0.5], so the right edge of the i-th quarter
    # is at i * channel - 0.5. A tick at i * channel would put "4.0x" past the
    # image extent (the axis limit), where it is not drawn.
    ax_ffn.set_xticks([i * channel - 0.5 for i in range(1, 5)],
                      [f'{i}.0x' for i in range(1, 5)])
    # Layer labels at the centre of each band; minor ticks on band edges.
    band_centre = (row_repeat - 1) / 2
    ax_ffn.set_yticks([i * row_repeat + band_centre for i in range(ffn_layers)],
                      [str(i + 1) for i in range(ffn_layers)])
    ax_ffn.set_yticks([i * row_repeat - 0.5 for i in range(ffn_layers)],
                      minor=True)
    ax_ffn.xaxis.tick_bottom()
    ax_ffn.yaxis.tick_right()
    ax_ffn.tick_params('both', length=0, width=0, which='both')
    # Horizontal white lines separating layers (minor y ticks only).
    ax_ffn.grid(which='minor', axis='y', color='w', linestyle='-', linewidth=1)
    ax_ffn.set_title('FFNs')

    # Set the suptitle before tight_layout so the layout reserves room for it
    # instead of letting it overlap the panel titles.
    fig.suptitle(name)
    fig.tight_layout()

    return fig