(self, x)
| 35 | self.fc_weight = nn.Linear(c.dim, c.num_components) |
| 36 | |
| 37 | def _square_plus(self, x): |
| 38 | return (x + T.sqrt(T.square(x) + 4)) / 2 |
| 39 | |
| 40 | def input(self, sampled_latents: T.Tensor) -> T.Tensor: |
| 41 | """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)""" |