(a, t, x_shape)
| 35 | |
| 36 | |
def extract_into_tensor(a, t, x_shape):
    """Look up entries of ``a`` at indices ``t`` and shape them for broadcasting.

    Values are gathered from ``a`` along its last dimension using ``t`` as the
    index, then reshaped to ``(b, 1, ..., 1)`` where ``b`` is the first
    dimension of ``t`` and the result has ``len(x_shape)`` dimensions in total.

    Args:
        a: Object exposing ``gather(dim, index)``.
            NOTE(review): presumably a 1-D torch tensor of per-step values;
            confirm against callers.
        t: Index object exposing ``.shape``; its leading dimension is used as
            the batch size.
        x_shape: Shape of the target the result should broadcast against;
            only its length is used.

    Returns:
        The gathered values reshaped to ``(b,) + (1,) * (len(x_shape) - 1)``.
    """
    # Only the leading dimension of ``t`` matters for the output shape.
    batch_size, *_rest = t.shape
    gathered = a.gather(-1, t)
    # One singleton axis for every dimension of the target after the first.
    singleton_axes = (1,) * (len(x_shape) - 1)
    return gathered.reshape((batch_size,) + singleton_axes)
| 41 | |
| 42 | |
| 43 | class MVDiffusion(pl.LightningModule): |
no outgoing calls
no test coverage detected