(a, t, x_shape)
| 94 | |
| 95 | |
def extract_into_tensor(a, t, x_shape):
    """Pick values of ``a`` at indices ``t`` and shape them for broadcasting.

    The values are selected along the last dimension of ``a`` via
    ``a.gather(-1, t)``. They are then reshaped to ``(batch, 1, ..., 1)``,
    which has ``len(x_shape)`` dimensions in total. The result therefore
    broadcasts against a tensor whose shape is ``x_shape``.

    Args:
        a: Tensor to select from. Presumably a 1-D per-step table of
            coefficients; TODO confirm against callers.
        t: Index tensor passed to ``gather``. Its leading dimension is used
            as the batch size.
        x_shape: Shape of the tensor the result should broadcast against.
            Only its length is used.

    Returns:
        The gathered values, reshaped to
        ``(t.shape[0],) + (1,) * (len(x_shape) - 1)``. When ``x_shape`` is
        empty, the trailing part is empty and the shape is just the batch
        size.
    """
    # Unpacking (rather than indexing) keeps the original error type for 0-d t.
    batch_size, *_ = t.shape
    picked = a.gather(-1, t)
    # One singleton axis per non-batch dimension of the target shape.
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *trailing_ones)
| 100 | |
| 101 | |
| 102 | def checkpoint(func, inputs, params, flag): |
no outgoing calls
no test coverage detected