(in_channels, motion_module_type: str, motion_module_kwargs: dict)
| 32 | |
| 33 | |
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
    """Build a motion module of the requested type.

    Args:
        in_channels: Forwarded unchanged as the ``in_channels`` keyword of the
            module constructor.
        motion_module_type: Name of the module variant to build. Only
            ``"Vanilla"`` is currently supported.
        motion_module_kwargs: Extra keyword arguments, expanded with ``**`` into
            the module constructor.

    Returns:
        A ``VanillaTemporalModule`` instance when ``motion_module_type`` is
        ``"Vanilla"``.

    Raises:
        ValueError: If ``motion_module_type`` is not a supported variant.
    """
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(
            in_channels=in_channels,
            **motion_module_kwargs,
        )
    # Name the offending value and the supported options, so misconfigured
    # callers get an actionable error instead of a bare ValueError().
    raise ValueError(
        f"Unsupported motion_module_type {motion_module_type!r}; "
        f"expected one of: 'Vanilla'"
    )
| 42 | |
| 43 | |
| 44 | class VanillaTemporalModule(nn.Module): |