(n_layers)
| 120 | return x, w1, w2 |
| 121 | |
def run_layers(n_layers):
    """Run a jitted stack of ``n_layers`` matmul pairs and report memory in use.

    A TinyJit-wrapped function is called three times, each time on fresh inputs
    from ``make_inp()``. Inputs and outputs are dropped, and a GC pass runs,
    before ``GlobalCounters.mem_used`` is read. The jitted wrapper is still
    referenced when the counter is read.

    NOTE(review): whether ``GlobalCounters.reset()`` also clears ``mem_used`` is
    not visible from here. If it does not, the result includes any memory
    allocated before this call. TODO confirm.

    Args:
        n_layers: number of ``x @ w1 @ w2`` repetitions inside the jitted body.

    Returns:
        The value of ``GlobalCounters.mem_used`` after the runs and the GC pass.
    """
    GlobalCounters.reset()

    def stacked(x, w1, w2):
        # Each "layer" applies the same pair of weight matrices; @ is
        # left-associative, so this is (x @ w1) @ w2.
        for _layer in range(n_layers):
            x = x @ w1 @ w2
        return x.contiguous()

    jitted = TinyJit(stacked)

    # Three calls with fresh inputs; presumably enough for TinyJit to capture
    # and then replay the kernels. TODO confirm.
    for _call in range(3):
        inputs = make_inp()
        out = jitted(*inputs)
        # Drop references explicitly so the GC pass below can reclaim them.
        del inputs, out

    gc.collect()
    return GlobalCounters.mem_used
| 138 | |
# Measure memory left in use after jitting 2-layer and 4-layer stacks
# (evaluated left to right: 2 layers first, then 4).
mem_2, mem_4 = run_layers(2), run_layers(4)