| 214 | self.initialized = not ddi |
| 215 | |
def initialize(self, x, x_mask):
    """Data-dependent initialization of ``self.bias`` and ``self.logs``.

    Computes the per-channel masked mean and standard deviation of ``x``
    (reducing over dims 0 and 2). It then sets ``logs = -log(std)`` and
    ``bias = -mean / std``, so the affine map ``bias + exp(logs) * x``
    standardizes ``x``: zero mean and unit variance over the masked positions.
    Presumably this is the map the layer's forward applies; that is not
    visible here.

    The standard deviation is floored at 1e-3 (variance clamped to 1e-6).

    Args:
        x: Input tensor. Assumed (batch, channels, time) — TODO confirm.
        x_mask: Mask/weights multiplied into ``x``. Presumably
            (batch, 1, time) with 1 for valid frames — verify against caller.

    Note:
        If ``x_mask`` sums to zero, the statistics are 0/0 (NaN). This
        matches the previous behavior, so callers should pass at least one
        valid frame.
    """
    with torch.no_grad():
        # Accumulate in at least float32. Sums of squares over batch*time can
        # overflow or lose precision in half precision. promote_types keeps
        # float64 inputs in float64.
        acc_dtype = torch.promote_types(x.dtype, torch.float32)
        x_acc = x.to(acc_dtype)
        mask_acc = x_mask.to(acc_dtype)

        dims = [0, 2]
        denom = torch.sum(mask_acc, dims, keepdim=True)
        mean = torch.sum(x_acc * mask_acc, dims, keepdim=True) / denom
        # Two-pass (centered) variance. This avoids the catastrophic
        # cancellation of E[x^2] - E[x]^2 when |mean| >> std.
        var = torch.sum(((x_acc - mean) ** 2) * mask_acc, dims, keepdim=True) / denom
        logs = 0.5 * torch.log(torch.clamp_min(var, 1e-6))  # log(std)

        bias_init = (-mean * torch.exp(-logs)).reshape(self.bias.shape)
        logs_init = (-logs).reshape(self.logs.shape)

        self.bias.data.copy_(bias_init.to(dtype=self.bias.dtype))
        self.logs.data.copy_(logs_init.to(dtype=self.logs.dtype))
| 229 | |
| 230 | |
| 231 | class InvConvNear(nn.Module): |