(self, x, scale)
| 781 | return x_recon, mu |
| 782 | |
def encode(self, x, scale):
    """Encode a video tensor into shifted/scaled latent means.

    The input is patchified (``patch_size=2``) and then passed through
    ``self.encoder`` in temporal chunks along dim 2. The first frame is
    encoded on its own, followed by consecutive groups of four frames. The
    same ``self._enc_feat_map`` cache is passed to every chunk call, so the
    chunks are processed strictly in order (presumably a causal cache; TODO
    confirm). The per-chunk outputs are concatenated along dim 2, projected
    by ``self.conv1``, and split in two along dim 1. Only the first half
    (``mu``) is used.

    Only the first ``1 + 4 * ((t - 1) // 4)`` frames are encoded, where
    ``t`` is the size of dim 2 after patchify. Any trailing frames that do
    not fill a complete group of four are ignored. Callers presumably pass
    ``4k + 1`` frames.

    Args:
        x: Input video. Slicing below assumes a 5-D tensor with time on
            dim 2, presumably (batch, channels, time, height, width); TODO
            confirm against callers.
        scale: Pair ``(shift, multiplier)``. If ``scale[0]`` is a
            ``torch.Tensor``, both elements are reshaped to
            ``(1, self.z_dim, 1, 1, 1)`` so they broadcast over dim 1.
            Otherwise they are used as-is, e.g. Python scalars.

    Returns:
        ``(mu - scale[0]) * scale[1]``.

    Raises:
        ValueError: If dim 2 of the patchified input is empty.
    """
    self.clear_cache()
    try:
        x = patchify(x, patch_size=2)
        num_frames = x.shape[2]
        if num_frames < 1:
            # Without this guard the chunk loop runs zero times and `out`
            # is never bound, failing later with an opaque UnboundLocalError.
            raise ValueError(
                "encode() requires at least one frame along dim 2, got 0")

        num_chunks = 1 + (num_frames - 1) // 4
        chunks = []
        for i in range(num_chunks):
            # feat_idx is reset to [0] before every chunk; presumably the
            # encoder advances it as it walks the cached layers.
            self._enc_conv_idx = [0]
            if i == 0:
                frames = x[:, :, :1, :, :]
            else:
                frames = x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :]
            chunks.append(
                self.encoder(
                    frames,
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                ))
        # Concatenate once rather than growing (and re-copying) the output
        # tensor on every iteration.
        out = torch.cat(chunks, 2)

        mu, _log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
                1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]
        return mu
    finally:
        # Always drop the encoder cache, including when encoding fails
        # part-way, so stale state is not left behind for the next call.
        self.clear_cache()
| 811 | |
| 812 | def decode(self, z, scale): |
| 813 | self.clear_cache() |
no test coverage detected