| 312 | |
| 313 | class InvConv(nn.Module): |
| 314 | def __init__(self, channels, no_jacobian=False, lu=True, **kwargs): |
| 315 | super().__init__() |
| 316 | w_shape = [channels, channels] |
| 317 | w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float) |
| 318 | LU_decomposed = lu |
| 319 | if not LU_decomposed: |
| 320 | # Sample a random orthogonal matrix: |
| 321 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) |
| 322 | else: |
| 323 | np_p, np_l, np_u = linalg.lu(w_init) |
| 324 | np_s = np.diag(np_u) |
| 325 | np_sign_s = np.sign(np_s) |
| 326 | np_log_s = np.log(np.abs(np_s)) |
| 327 | np_u = np.triu(np_u, k=1) |
| 328 | l_mask = np.tril(np.ones(w_shape, dtype=float), -1) |
| 329 | eye = np.eye(*w_shape, dtype=float) |
| 330 | |
| 331 | self.register_buffer('p', torch.Tensor(np_p.astype(float))) |
| 332 | self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float))) |
| 333 | self.l = nn.Parameter(torch.Tensor(np_l.astype(float))) |
| 334 | self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float))) |
| 335 | self.u = nn.Parameter(torch.Tensor(np_u.astype(float))) |
| 336 | self.l_mask = torch.Tensor(l_mask) |
| 337 | self.eye = torch.Tensor(eye) |
| 338 | self.w_shape = w_shape |
| 339 | self.LU = LU_decomposed |
| 340 | self.weight = None |
| 341 | |
| 342 | def get_weight(self, device, reverse): |
| 343 | w_shape = self.w_shape |