| 854 | |
@beartype
class ResidualUnitMod(Module):
    """Residual unit wrapped around a modulated 3D convolution.

    The conditioning tensor is linearly projected to ``dim`` features and
    passed to ``Conv3DMod`` as its ``cond``. The convolution output then goes
    through ELU, a 1x1x1 ``nn.Conv3d`` and a second ELU. Finally the
    unmodified input is added back.
    """

    def __init__(
        self,
        dim,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        dim_cond,
        pad_mode: str = "constant",
        demod=True,
    ):
        """
        Args:
            dim: channel count of the input/output. It is also the width of
                the projected conditioning vector.
            kernel_size: a single int or a ``(time, height, width)`` triple.
                It is expanded with ``cast_tuple(kernel_size, 3)``. Height and
                width must be equal, because ``Conv3DMod`` receives a single
                ``spatial_kernel`` value.
            dim_cond: size of the last axis of the ``cond`` tensor given to
                ``forward`` (input width of ``self.to_cond``).
            pad_mode: forwarded unchanged to ``Conv3DMod``.
            demod: forwarded unchanged to ``Conv3DMod``.
        """
        super().__init__()

        t_kernel, h_kernel, w_kernel = cast_tuple(kernel_size, 3)
        # Only one spatial kernel size is handed to Conv3DMod below, so the
        # spatial footprint has to be square.
        assert h_kernel == w_kernel

        # NOTE: keep this creation order. It fixes submodule registration
        # order (state_dict / parameters()) and the order in which random
        # weight initialisation happens.
        self.to_cond = nn.Linear(dim_cond, dim)
        self.conv = Conv3DMod(
            dim=dim,
            spatial_kernel=h_kernel,
            time_kernel=t_kernel,
            causal=True,
            demod=demod,
            pad_mode=pad_mode,
        )
        self.conv_out = nn.Conv3d(dim, dim, 1)

    @beartype
    def forward(
        self,
        x,
        cond: Tensor,
    ):
        """Run the modulated residual unit.

        Args:
            x: input features. Its channel axis must have size ``dim`` for
                ``self.conv_out``. Presumably shaped
                (batch, dim, time, height, width); TODO confirm against
                ``Conv3DMod``.
            cond: conditioning tensor whose last axis has size ``dim_cond``.

        Returns:
            ``elu(conv_out(elu(conv(x, cond=to_cond(cond))))) + x``. The
            residual add presumes ``self.conv`` preserves ``x``'s shape.
        """
        modulation = self.to_cond(cond)

        hidden = F.elu(self.conv(x, cond=modulation))
        hidden = F.elu(self.conv_out(hidden))

        # Skip connection back to the untouched input.
        return hidden + x
| 892 | |
| 893 | |
| 894 | class CausalConvTranspose3d(Module): |