| 230 | |
| 231 | class InvConvNear(nn.Module): |
| 232 | def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs): |
| 233 | super().__init__() |
| 234 | assert (n_split % 2 == 0) |
| 235 | self.channels = channels |
| 236 | self.n_split = n_split |
| 237 | self.n_sqz = n_sqz |
| 238 | self.no_jacobian = no_jacobian |
| 239 | |
| 240 | w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0] |
| 241 | if torch.det(w_init) < 0: |
| 242 | w_init[:, 0] = -1 * w_init[:, 0] |
| 243 | self.lu = lu |
| 244 | if lu: |
| 245 | # LU decomposition can slightly speed up the inverse |
| 246 | np_p, np_l, np_u = linalg.lu(w_init) |
| 247 | np_s = np.diag(np_u) |
| 248 | np_sign_s = np.sign(np_s) |
| 249 | np_log_s = np.log(np.abs(np_s)) |
| 250 | np_u = np.triu(np_u, k=1) |
| 251 | l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1) |
| 252 | eye = np.eye(*w_init.shape, dtype=float) |
| 253 | |
| 254 | self.register_buffer('p', torch.Tensor(np_p.astype(float))) |
| 255 | self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float))) |
| 256 | self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True) |
| 257 | self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True) |
| 258 | self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True) |
| 259 | self.register_buffer('l_mask', torch.Tensor(l_mask)) |
| 260 | self.register_buffer('eye', torch.Tensor(eye)) |
| 261 | else: |
| 262 | self.weight = nn.Parameter(w_init) |
| 263 | |
| 264 | def forward(self, x, x_mask=None, reverse=False, **kwargs): |
| 265 | b, c, t = x.size() |