Initializes the image-to-video generation model components. Args: config (EasyDict): object containing model parameters initialized from config.py; checkpoint_dir (`str`): path to the directory containing model checkpoints.
(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
)
| 33 | class WanI2V: |
| 34 | |
| 35 | def __init__( |
| 36 | self, |
| 37 | config, |
| 38 | checkpoint_dir, |
| 39 | device_id=0, |
| 40 | rank=0, |
| 41 | t5_fsdp=False, |
| 42 | dit_fsdp=False, |
| 43 | use_sp=False, |
| 44 | t5_cpu=False, |
| 45 | init_on_cpu=True, |
| 46 | convert_model_dtype=False, |
| 47 | ): |
| 48 | r""" |
| 49 | Initializes the image-to-video generation model components. |
| 50 | |
| 51 | Args: |
| 52 | config (EasyDict): |
| 53 | Object containing model parameters initialized from config.py |
| 54 | checkpoint_dir (`str`): |
| 55 | Path to directory containing model checkpoints |
| 56 | device_id (`int`, *optional*, defaults to 0): |
| 57 | Id of target GPU device |
| 58 | rank (`int`, *optional*, defaults to 0): |
| 59 | Process rank for distributed training |
| 60 | t5_fsdp (`bool`, *optional*, defaults to False): |
| 61 | Enable FSDP sharding for T5 model |
| 62 | dit_fsdp (`bool`, *optional*, defaults to False): |
| 63 | Enable FSDP sharding for DiT model |
| 64 | use_sp (`bool`, *optional*, defaults to False): |
| 65 | Enable distribution strategy of sequence parallel. |
| 66 | t5_cpu (`bool`, *optional*, defaults to False): |
| 67 | Whether to place T5 model on CPU. Only works without t5_fsdp. |
| 68 | init_on_cpu (`bool`, *optional*, defaults to True): |
| 69 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. |
| 70 | convert_model_dtype (`bool`, *optional*, defaults to False): |
| 71 | Convert DiT model parameters dtype to 'config.param_dtype'. |
| 72 | Only works without FSDP. |
| 73 | """ |
| 74 | self.device = torch.device(f"cuda:{device_id}") |
| 75 | self.config = config |
| 76 | self.rank = rank |
| 77 | self.t5_cpu = t5_cpu |
| 78 | self.init_on_cpu = init_on_cpu |
| 79 | |
| 80 | self.num_train_timesteps = config.num_train_timesteps |
| 81 | self.boundary = config.boundary |
| 82 | self.param_dtype = config.param_dtype |
| 83 | |
| 84 | if t5_fsdp or dit_fsdp or use_sp: |
| 85 | self.init_on_cpu = False |
| 86 | |
| 87 | shard_fn = partial(shard_model, device_id=device_id) |
| 88 | self.text_encoder = T5EncoderModel( |
| 89 | text_len=config.text_len, |
| 90 | dtype=config.t5_dtype, |
| 91 | device=torch.device('cpu'), |
| 92 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), |
nothing calls this directly
no test coverage detected