float tensor (B, L, D) -> Tuple of locs, scales, and weights
(self, h: T.Tensor)
| 43 | return hidden |
| 44 | |
def output(self, h: T.Tensor, *, min_scale: float = 1e-6) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
    """Project hidden states to mixture-component parameters.

    Args:
        h: float tensor of shape (B, L, D). Any number of leading batch
            dimensions is accepted (e.g. (B, D) or (B, L, D)); only the last
            (feature) dimension is consumed by the projection layers.
        min_scale: lower bound applied to the scales after ``_square_plus``.
            Defaults to 1e-6 (the previously hard-coded value).

    Returns:
        Tuple ``(locs, scales, weights)``:
          - locs:    shape ``(*h.shape[:-1], num_components, latent_dim)``
          - scales:  same shape as locs, every entry >= ``min_scale``
          - weights: shape ``(*h.shape[:-1], num_components)``; returned exactly
            as produced by ``fc_weight`` (no normalization is applied here).
    """
    # Keep every leading dim so the head works for (B, L, D), (B, D), ...;
    # for 3-D input this is identical to viewing with (batch_size, seq_len).
    batch_dims = h.shape[:-1]
    comp_shape = (*batch_dims, self.num_components, self.latent_dim)

    locs = self.fc_loc(h).view(comp_shape)
    # Clamp guarantees scales stay >= min_scale whatever _square_plus returns.
    scales = T.clamp(self._square_plus(self.fc_scale(h)), min=min_scale).view(comp_shape)
    weights = self.fc_weight(h).view(*batch_dims, self.num_components)

    return locs, scales, weights
| 54 | |
| 55 | def loss(self, data, dataHat): |
| 56 | locs, scales, weights = dataHat |
no test coverage detected