r""" Initializes the image-to-video generation model components. Args: config (EasyDict): Object containing model parameters initialized from config.py checkpoint_dir (`str`): Path to directory containing model checkpoints
(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
)
| 44 | class WanI2V: |
| 45 | |
| 46 | def __init__( |
| 47 | self, |
| 48 | config, |
| 49 | checkpoint_dir, |
| 50 | device_id=0, |
| 51 | rank=0, |
| 52 | t5_fsdp=False, |
| 53 | dit_fsdp=False, |
| 54 | use_sp=False, |
| 55 | t5_cpu=False, |
| 56 | init_on_cpu=True, |
| 57 | convert_model_dtype=False, |
| 58 | ): |
| 59 | r""" |
| 60 | Initializes the image-to-video generation model components. |
| 61 | |
| 62 | Args: |
| 63 | config (EasyDict): |
| 64 | Object containing model parameters initialized from config.py |
| 65 | checkpoint_dir (`str`): |
| 66 | Path to directory containing model checkpoints |
| 67 | device_id (`int`, *optional*, defaults to 0): |
| 68 | Id of target GPU device |
| 69 | rank (`int`, *optional*, defaults to 0): |
| 70 | Process rank for distributed training |
| 71 | t5_fsdp (`bool`, *optional*, defaults to False): |
| 72 | Enable FSDP sharding for T5 model |
| 73 | dit_fsdp (`bool`, *optional*, defaults to False): |
| 74 | Enable FSDP sharding for DiT model |
| 75 | use_sp (`bool`, *optional*, defaults to False): |
| 76 | Enable the sequence-parallel distribution strategy. |
| 77 | t5_cpu (`bool`, *optional*, defaults to False): |
| 78 | Whether to place T5 model on CPU. Only works without t5_fsdp. |
| 79 | init_on_cpu (`bool`, *optional*, defaults to True): |
| 80 | Initialize the Transformer model on CPU. Forced to False when t5_fsdp, dit_fsdp, or use_sp is enabled. |
| 81 | convert_model_dtype (`bool`, *optional*, defaults to False): |
| 82 | Convert DiT model parameters dtype to 'config.param_dtype'. |
| 83 | Only works without FSDP. |
| 84 | """ |
| 85 | self.device = torch.device(f"cuda:{device_id}") |
| 86 | self.config = config |
| 87 | self.rank = rank |
| 88 | self.t5_cpu = t5_cpu |
| 89 | self.init_on_cpu = init_on_cpu |
| 90 | |
| 91 | self.num_train_timesteps = config.num_train_timesteps |
| 92 | self.boundary = config.boundary |
| 93 | self.param_dtype = config.param_dtype |
| 94 | |
| 95 | if t5_fsdp or dit_fsdp or use_sp: |
| 96 | self.init_on_cpu = False |
| 97 | |
| 98 | if 'cam' in checkpoint_dir: |
| 99 | self.control_type = 'cam' |
| 100 | elif 'act' in checkpoint_dir: |
| 101 | self.control_type = 'act' |
| 102 | |
| 103 | shard_fn = partial(shard_model, device_id=device_id) |
nothing calls this directly
no test coverage detected