MCPcopy
hub / github.com/zai-org/CogVideo / ResidualUnitMod

Class ResidualUnitMod

sat/sgm/modules/autoencoding/magvit2_pytorch.py:856–891  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

854
@beartype
class ResidualUnitMod(Module):
    """Residual unit whose main 3D convolution is modulated by a conditioning input.

    ``forward`` computes::

        elu(conv_out(elu(conv(x, cond=to_cond(cond))))) + x

    Args:
        dim: Channel count used by every layer of the unit (the 1x1x1
            ``conv_out`` maps ``dim`` -> ``dim``, and the output is added to
            the input).
        kernel_size: A single int or a ``(time, height, width)`` triple; it is
            normalised with ``cast_tuple(kernel_size, 3)`` and unpacked into
            three sizes.
        dim_cond: Last-dimension size of the conditioning tensor; a linear
            layer projects it to ``dim`` before it reaches ``Conv3DMod``.
        pad_mode: Forwarded unchanged to ``Conv3DMod``.
        demod: Forwarded unchanged to ``Conv3DMod``.

    Raises:
        AssertionError: If the height and width kernel sizes differ (only
            while assertions are enabled).
    """

    def __init__(
        self,
        dim,
        kernel_size: Union[int, Tuple[int, int, int]],
        *,
        dim_cond,
        pad_mode: str = "constant",
        demod=True,
    ):
        super().__init__()

        t_kernel, h_kernel, w_kernel = cast_tuple(kernel_size, 3)
        # Conv3DMod accepts a single `spatial_kernel`, so the spatial
        # footprint has to be square.
        assert h_kernel == w_kernel

        # Construction order (to_cond -> conv -> conv_out) is intentional to
        # keep: layer constructors may draw initial weights from the global
        # RNG, and registration order fixes the parameter / state_dict order.
        self.to_cond = nn.Linear(dim_cond, dim)
        self.conv = Conv3DMod(
            dim=dim,
            spatial_kernel=h_kernel,
            time_kernel=t_kernel,
            causal=True,
            demod=demod,
            pad_mode=pad_mode,
        )
        self.conv_out = nn.Conv3d(dim, dim, 1)

    @beartype
    def forward(
        self,
        x,
        cond: Tensor,
    ):
        """Run the modulated conv stack and add the input back.

        Args:
            x: Input activation. Presumably ``(batch, dim, time, height,
                width)``, since ``conv_out`` is an ``nn.Conv3d`` over ``dim``
                channels; TODO confirm what ``Conv3DMod`` expects.
            cond: Conditioning tensor whose last dimension is ``dim_cond``
                (it is fed through ``nn.Linear``). How ``Conv3DMod`` consumes
                the projected vector is defined elsewhere.

        Returns:
            The transformed activation plus ``x``. The conv stack's output
            must therefore have the same shape as ``x``.
        """
        skip = x
        style = self.to_cond(cond)

        hidden = F.elu(self.conv(x, cond=style))
        hidden = F.elu(self.conv_out(hidden))
        return hidden + skip
892
893
894class CausalConvTranspose3d(Module):

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected