(self, x, x_mask=None, reverse=False, **kwargs)
| 192 | self.bias = nn.Parameter(torch.zeros(1, channels, 1)) |
| 193 | |
def forward(self, x, x_mask=None, reverse=False, **kwargs):
    """Apply the per-channel affine transform, or its inverse.

    Forward:  ``z = (bias + exp(logs) * x) * x_mask``
    Reverse:  ``z = (x - bias) * exp(-logs) * x_mask``

    Args:
        x: Input tensor. Only ``x.size(0)`` and ``x.size(2)`` are read here;
            presumably the layout is ``(batch, channels, time)`` — TODO confirm.
        x_mask: Optional mask multiplied into the output (broadcast against
            ``x``). When ``None``, an all-ones mask of shape
            ``(x.size(0), 1, x.size(2))`` on ``x``'s device/dtype is used.
        reverse: If true, apply the inverse transform.
        **kwargs: Accepted and ignored.

    Returns:
        A tuple ``(z, logdet)``. ``logdet`` is ``sum(logs) * x_len`` in the
        forward direction and ``sum(-logs) * x_len`` in reverse, where
        ``x_len`` is ``x_mask`` summed over dims 1 and 2.
    """
    if x_mask is None:
        # Allocate directly on the target device/dtype rather than creating a
        # CPU tensor and copying it over with .to().
        x_mask = torch.ones(x.size(0), 1, x.size(2), device=x.device, dtype=x.dtype)
    x_len = torch.sum(x_mask, [1, 2])  # [b]
    if not self.initialized:
        # One-time initialization from the first (x, x_mask) seen;
        # self.initialize is defined elsewhere in the class.
        self.initialize(x, x_mask)
        self.initialized = True

    if reverse:
        z = (x - self.bias) * torch.exp(-self.logs) * x_mask
        logdet = torch.sum(-self.logs) * x_len
    else:
        z = (self.bias + torch.exp(self.logs) * x) * x_mask
        logdet = torch.sum(self.logs) * x_len  # [b]
    return z, logdet
| 209 | |
def store_inverse(self):
    """Do nothing and return ``None``.

    NOTE(review): this layer's inverse is computed on the fly in
    ``forward(..., reverse=True)``, so there is nothing to precompute here.
    Presumably the hook exists for API parity with other layers — confirm.
    """
nothing calls this directly
no test coverage detected