Repository: github.com/microsoft/Cream

Function plot

TinyCLIP/src/training/viz.py:19–66
plot(heads, intermediates, name)

```python
import matplotlib.pyplot as plt

# NOTE: cmap="custom" assumes a colormap registered under that name elsewhere
# in viz.py (lines 1-18, not shown on this page).


def plot(heads, intermediates, name):
    """Draw per-layer MHA head values and FFN channel values side by side.

    heads: NumPy array of shape (layers, heads_per_layer).
    intermediates: NumPy array of shape (layers, ffn_channels).
    Both are drawn on a fixed 0.0-1.0 colour scale.
    """
    fig, ax = plt.subplots(1, 2, facecolor='white', figsize=(10, 4), dpi=120,
                           gridspec_kw={'width_ratios': [1.15, 3]})

    # Left panel: one cell per attention head, one row per layer.
    layers_num, heads_num = heads.shape
    ax[0].matshow(heads, cmap="custom", vmin=0.0, vmax=1.0)
    ax[0].set_xlabel("Heads")
    ax[0].set_ylabel("Layer")
    ax[0].set_xticks(list(range(heads_num)), [str(i + 1) for i in range(heads_num)])
    ax[0].set_yticks(list(range(layers_num)), [str(i + 1) for i in range(layers_num)])
    # Minor ticks on the cell edges
    ax[0].set_xticks([i - 0.5 for i in range(heads_num)], minor=True)
    ax[0].set_yticks([i - 0.5 for i in range(layers_num)], minor=True)
    ax[0].xaxis.tick_bottom()
    ax[0].tick_params('both', length=0, width=0, which='both')

    # White gridlines on the minor ticks separate the cells
    ax[0].grid(which='minor', color='w', linestyle='-', linewidth=1)
    ax[0].set_title('MHAs')

    # Right panel: FFN channels. `channel` is one quarter of the FFN width,
    # so the x ticks read 1.0x .. 4.0x.
    channel = intermediates.shape[1] / 4
    ffn_layers_num = intermediates.shape[0]
    # Stretch each layer's row to 100 pixel rows so the bands are readable
    intermediates = intermediates.repeat(100, axis=0)
    ax[1].matshow(intermediates, cmap="custom", vmin=0.0, vmax=1.0)
    ax[1].set_xlabel("FFNs channels")

    ax[1].set_xticks([i * channel for i in range(1, 5)],
                     [f'{i}.0x' for i in range(1, 5)])
    # Layer labels centred in each 100-row band; minor ticks at the band starts
    ax[1].set_yticks([i * 100 + 50 for i in range(ffn_layers_num)],
                     [str(i + 1) for i in range(ffn_layers_num)])
    ax[1].set_yticks([i * 100 for i in range(ffn_layers_num)], minor=True)

    ax[1].xaxis.tick_bottom()
    ax[1].yaxis.tick_right()
    ax[1].tick_params('both', length=0, width=0, which='both')

    # Horizontal white gridlines on the minor ticks separate the layers
    ax[1].grid(which='minor', axis='y', color='w', linestyle='-', linewidth=1)
    ax[1].set_title('FFNs')

    # Set the suptitle before tight_layout so the layout can leave room for it
    fig.suptitle(name)
    fig.tight_layout()

    return fig
```
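Most of the body of `plot` is tick bookkeeping. In the MHA panel, major ticks sit on cell centres and minor ticks at `i - 0.5` mark the cell edges where the white grid is drawn. In the FFN panel, each layer row is repeated 100 times, so layer labels go at `i * 100 + 50` and separators at `i * 100`, while the channel axis is labelled in quarters of its width (presumably because the FFN hidden width is 4x the embedding width in a standard ViT). The sketch below reproduces that arithmetic in plain Python so it can be checked without matplotlib; the helper names `mha_ticks` and `ffn_ticks` and the 12-layer, 3072-channel example are hypothetical and not part of TinyCLIP.

```python
"""Plain-Python sketch of the tick arithmetic in plot() (hypothetical helpers)."""


def mha_ticks(num_layers: int, num_heads: int) -> dict:
    """Tick layout of the MHA panel: one matshow cell per (layer, head)."""
    return {
        # matshow centres cell j on coordinate j, so major ticks sit on cell centres ...
        "x": list(range(num_heads)),
        "x_labels": [str(j + 1) for j in range(num_heads)],
        "y": list(range(num_layers)),
        "y_labels": [str(i + 1) for i in range(num_layers)],
        # ... and minor ticks at j - 0.5 sit on the cell edges (white grid lines).
        "x_minor": [j - 0.5 for j in range(num_heads)],
        "y_minor": [i - 0.5 for i in range(num_layers)],
    }


def ffn_ticks(num_layers: int, num_channels: int, rows_per_layer: int = 100) -> dict:
    """Tick layout of the FFN panel after each layer row is repeated rows_per_layer times."""
    quarter = num_channels / 4  # width of one "1.0x" step on the channel axis
    return {
        "x": [k * quarter for k in range(1, 5)],
        "x_labels": [f"{k}.0x" for k in range(1, 5)],
        # Layer i fills image rows i*rows_per_layer .. (i+1)*rows_per_layer - 1,
        # so its label goes in the middle of that band ...
        "y": [i * rows_per_layer + rows_per_layer // 2 for i in range(num_layers)],
        "y_labels": [str(i + 1) for i in range(num_layers)],
        # ... and the minor ticks (white separators) at the start of each band.
        "y_minor": [i * rows_per_layer for i in range(num_layers)],
    }


if __name__ == "__main__":
    ffn = ffn_ticks(num_layers=12, num_channels=3072)
    print(ffn["x"])                          # [768.0, 1536.0, 2304.0, 3072.0]
    print(ffn["y"][:3], ffn["y_minor"][:3])  # [50, 150, 250] [0, 100, 200]
```

Because matshow centres pixel row `r` on coordinate `r`, a band's true top edge is at `i * 100 - 0.5`, so the separators drawn at `i * 100` are offset by half a pixel; at 100 rows per layer this is not visible.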

Callers (1)

train_one_epoch (function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
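`plot` passes `cmap="custom"` to both `matshow` calls, which only works if a colormap is registered under that name before the call; presumably the part of viz.py above line 19, not shown on this page, does that. Without it, matplotlib rejects the unknown colormap name. A minimal sketch of such a registration, assuming matplotlib 3.5 or newer (for the `matplotlib.colormaps` registry); the two endpoint colours are placeholders, not TinyCLIP's actual palette.

```python
import matplotlib
from matplotlib.colors import LinearSegmentedColormap

# Placeholder palette (assumption): 0.0 -> light grey, 1.0 -> dark blue,
# matching the vmin=0.0 / vmax=1.0 scale used by plot().
CUSTOM_CMAP = LinearSegmentedColormap.from_list("custom", ["#f0f0f0", "#1f4e79"])

# Register only once: re-registering an existing name without force=True
# raises ValueError, so guard with a membership check.
if "custom" not in matplotlib.colormaps:
    matplotlib.colormaps.register(CUSTOM_CMAP)
```

The membership check keeps the snippet safe to re-run, for example in a notebook; `from_list` builds a 256-entry colormap by default.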