(
self,
*,
image_size,
layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"),
residual_conv_kernel_size=3,
num_codebooks=1,
codebook_size: Optional[int] = None,
channels=3,
init_dim=64,
max_dim=float("inf"),
dim_cond=None,
dim_cond_expansion_factor=4.0,
input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7),
output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3),
pad_mode: str = "constant",
lfq_entropy_loss_weight=0.1,
lfq_commitment_loss_weight=1.0,
lfq_diversity_gamma=2.5,
quantizer_aux_loss_weight=1.0,
lfq_activation=nn.Identity(),
use_fsq=False,
fsq_levels: Optional[List[int]] = None,
attn_dim_head=32,
attn_heads=8,
attn_dropout=0.0,
linear_attn_dim_head=8,
linear_attn_heads=16,
vgg: Optional[Module] = None,
vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT,
perceptual_loss_weight=1e-1,
discr_kwargs: Optional[dict] = None,
multiscale_discrs: Tuple[Module, ...] = tuple(),
use_gan=True,
adversarial_loss_weight=1.0,
grad_penalty_loss_weight=10.0,
multiscale_adversarial_loss_weight=1.0,
flash_attn=True,
separate_first_frame_encoding=False,
)
| 942 | class VideoTokenizer(Module): |
| 943 | @beartype |
| 944 | def __init__( |
| 945 | self, |
| 946 | *, |
| 947 | image_size, |
| 948 | layers: Tuple[Union[str, Tuple[str, int]], ...] = ("residual", "residual", "residual"), |
| 949 | residual_conv_kernel_size=3, |
| 950 | num_codebooks=1, |
| 951 | codebook_size: Optional[int] = None, |
| 952 | channels=3, |
| 953 | init_dim=64, |
| 954 | max_dim=float("inf"), |
| 955 | dim_cond=None, |
| 956 | dim_cond_expansion_factor=4.0, |
| 957 | input_conv_kernel_size: Tuple[int, int, int] = (7, 7, 7), |
| 958 | output_conv_kernel_size: Tuple[int, int, int] = (3, 3, 3), |
| 959 | pad_mode: str = "constant", |
| 960 | lfq_entropy_loss_weight=0.1, |
| 961 | lfq_commitment_loss_weight=1.0, |
| 962 | lfq_diversity_gamma=2.5, |
| 963 | quantizer_aux_loss_weight=1.0, |
| 964 | lfq_activation=nn.Identity(), |
| 965 | use_fsq=False, |
| 966 | fsq_levels: Optional[List[int]] = None, |
| 967 | attn_dim_head=32, |
| 968 | attn_heads=8, |
| 969 | attn_dropout=0.0, |
| 970 | linear_attn_dim_head=8, |
| 971 | linear_attn_heads=16, |
| 972 | vgg: Optional[Module] = None, |
| 973 | vgg_weights: VGG16_Weights = VGG16_Weights.DEFAULT, |
| 974 | perceptual_loss_weight=1e-1, |
| 975 | discr_kwargs: Optional[dict] = None, |
| 976 | multiscale_discrs: Tuple[Module, ...] = tuple(), |
| 977 | use_gan=True, |
| 978 | adversarial_loss_weight=1.0, |
| 979 | grad_penalty_loss_weight=10.0, |
| 980 | multiscale_adversarial_loss_weight=1.0, |
| 981 | flash_attn=True, |
| 982 | separate_first_frame_encoding=False, |
| 983 | ): |
| 984 | super().__init__() |
| 985 | |
| 986 | # for autosaving the config |
| 987 | |
| 988 | _locals = locals() |
| 989 | _locals.pop("self", None) |
| 990 | _locals.pop("__class__", None) |
| 991 | self._configs = pickle.dumps(_locals) |
| 992 | |
| 993 | # image size |
| 994 | |
| 995 | self.channels = channels |
| 996 | self.image_size = image_size |
| 997 | |
| 998 | # initial encoder |
| 999 | |
| 1000 | self.conv_in = CausalConv3d(channels, init_dim, input_conv_kernel_size, pad_mode=pad_mode) |
| 1001 |
No direct callers of `VideoTokenizer.__init__` were found.
No test coverage was detected for `VideoTokenizer.__init__`.