| 810 | self.quant_mode = quant_mode |
| 811 | |
def forward(self, x):
    """Pad, project and optionally normalise/flatten a 5-D input.

    The last three dimensions of ``x`` are padded at their trailing end so
    that each becomes a multiple of the matching ``self.patch_size`` entry,
    then ``self.proj`` is applied. When ``self.norm`` is set, the result is
    flattened from dimension 2 onwards, normalised, and restored to
    ``[-1, self.embed_dim, D, H, W]``. When ``self.flatten`` is truthy, the
    result is flattened from dimension 2 onwards and dimensions 1 and 2
    are swapped.

    Args:
        x: Input with exactly five dimensions; the last three are read as
            depth, height and width (presumably B C T H W -- confirm
            against the caller).

    Returns:
        The embedded tensor, also registered as network output 'output'.

    Raises:
        NotImplementedError: If ``USE_STATIC_SHAPE`` is falsy.
    """
    if not USE_STATIC_SHAPE:
        raise NotImplementedError('Only static shape is supported')

    _, _, depth, height, width = x.shape

    # Walk the spatial axes innermost-first (W, then H, then D). The pad
    # tuple is assumed to list (before, after) pairs starting from the last
    # dimension, so each outer axis needs one more leading (0, 0) pair.
    # NOTE(review): confirm this convention against the `pad` helper.
    for axis, extent in ((2, width), (1, height), (0, depth)):
        patch = self.patch_size[axis]
        remainder = extent % patch
        if remainder != 0:
            untouched_inner_axes = (0, 0) * (2 - axis)
            x = pad(x, untouched_inner_axes + (0, patch - remainder))

    x = self.proj(x)  # (B C T H W)

    if self.norm is not None:
        # Capture the post-projection grid sizes before flattening so the
        # 5-D layout can be rebuilt after normalisation.
        grid_d = shape(x, 2)
        grid_h = shape(x, 3)
        grid_w = shape(x, 4)
        tokens = x.flatten(2)
        tokens = tokens.transpose(1, 2)
        tokens = self.norm(tokens)
        tokens = tokens.transpose(1, 2)
        x = tokens.view([-1, self.embed_dim, grid_d, grid_h, grid_w])

    if self.flatten:
        # BCTHW -> BNC
        x = x.flatten(2)
        x = x.transpose(1, 2)

    self.register_network_output('output', x)
    return x
| 835 | |
| 836 | |
| 837 | class STDiT3Block(Module): |