r""" Initializes the image-to-video generation model components. Args: config (EasyDict): Object containing model parameters initialized from config.py; checkpoint_dir (`str`): Path to the directory containing model checkpoints
(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
init_on_cpu=True,
)
| 32 | class WanI2V: |
| 33 | |
| 34 | def __init__( |
| 35 | self, |
| 36 | config, |
| 37 | checkpoint_dir, |
| 38 | device_id=0, |
| 39 | rank=0, |
| 40 | t5_fsdp=False, |
| 41 | dit_fsdp=False, |
| 42 | use_usp=False, |
| 43 | t5_cpu=False, |
| 44 | init_on_cpu=True, |
| 45 | ): |
| 46 | r""" |
| 47 | Initializes the image-to-video generation model components. |
| 48 | |
| 49 | Args: |
| 50 | config (EasyDict): |
| 51 | Object containing model parameters initialized from config.py |
| 52 | checkpoint_dir (`str`): |
| 53 | Path to directory containing model checkpoints |
| 54 | device_id (`int`, *optional*, defaults to 0): |
| 55 | Id of target GPU device |
| 56 | rank (`int`, *optional*, defaults to 0): |
| 57 | Process rank for distributed training |
| 58 | t5_fsdp (`bool`, *optional*, defaults to False): |
| 59 | Enable FSDP sharding for T5 model |
| 60 | dit_fsdp (`bool`, *optional*, defaults to False): |
| 61 | Enable FSDP sharding for DiT model |
| 62 | use_usp (`bool`, *optional*, defaults to False): |
| 63 | Enable distribution strategy of USP. |
| 64 | t5_cpu (`bool`, *optional*, defaults to False): |
| 65 | Whether to place T5 model on CPU. Only works without t5_fsdp. |
| 66 | init_on_cpu (`bool`, *optional*, defaults to True): |
| 67 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. |
| 68 | """ |
| 69 | self.device = torch.device(f"cuda:{device_id}") |
| 70 | self.config = config |
| 71 | self.rank = rank |
| 72 | self.use_usp = use_usp |
| 73 | self.t5_cpu = t5_cpu |
| 74 | |
| 75 | self.num_train_timesteps = config.num_train_timesteps |
| 76 | self.param_dtype = config.param_dtype |
| 77 | |
| 78 | shard_fn = partial(shard_model, device_id=device_id) |
| 79 | self.text_encoder = T5EncoderModel( |
| 80 | text_len=config.text_len, |
| 81 | dtype=config.t5_dtype, |
| 82 | device=torch.device('cpu'), |
| 83 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), |
| 84 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), |
| 85 | shard_fn=shard_fn if t5_fsdp else None, |
| 86 | ) |
| 87 | |
| 88 | self.vae_stride = config.vae_stride |
| 89 | self.patch_size = config.patch_size |
| 90 | self.vae = WanVAE( |
| 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), |
nothing calls this directly
no test coverage detected