r""" Initializes the Wan text-to-video generation model components. Args: config (EasyDict): Object containing model parameters initialized from config.py. checkpoint_dir (`str`): Path to directory containing model checkpoints. """
(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_usp=False,
t5_cpu=False,
)
| 37 | class WanVace(WanT2V): |
| 38 | |
| 39 | def __init__( |
| 40 | self, |
| 41 | config, |
| 42 | checkpoint_dir, |
| 43 | device_id=0, |
| 44 | rank=0, |
| 45 | t5_fsdp=False, |
| 46 | dit_fsdp=False, |
| 47 | use_usp=False, |
| 48 | t5_cpu=False, |
| 49 | ): |
| 50 | r""" |
| 51 | Initializes the Wan text-to-video generation model components. |
| 52 | |
| 53 | Args: |
| 54 | config (EasyDict): |
| 55 | Object containing model parameters initialized from config.py |
| 56 | checkpoint_dir (`str`): |
| 57 | Path to directory containing model checkpoints |
| 58 | device_id (`int`, *optional*, defaults to 0): |
| 59 | Id of target GPU device |
| 60 | rank (`int`, *optional*, defaults to 0): |
| 61 | Process rank for distributed training |
| 62 | t5_fsdp (`bool`, *optional*, defaults to False): |
| 63 | Enable FSDP sharding for T5 model |
| 64 | dit_fsdp (`bool`, *optional*, defaults to False): |
| 65 | Enable FSDP sharding for DiT model |
| 66 | use_usp (`bool`, *optional*, defaults to False): |
| 67 | Enable distribution strategy of USP. |
| 68 | t5_cpu (`bool`, *optional*, defaults to False): |
| 69 | Whether to place T5 model on CPU. Only works without t5_fsdp. |
| 70 | """ |
| 71 | self.device = torch.device(f"cuda:{device_id}") |
| 72 | self.config = config |
| 73 | self.rank = rank |
| 74 | self.t5_cpu = t5_cpu |
| 75 | |
| 76 | self.num_train_timesteps = config.num_train_timesteps |
| 77 | self.param_dtype = config.param_dtype |
| 78 | |
| 79 | shard_fn = partial(shard_model, device_id=device_id) |
| 80 | self.text_encoder = T5EncoderModel( |
| 81 | text_len=config.text_len, |
| 82 | dtype=config.t5_dtype, |
| 83 | device=torch.device('cpu'), |
| 84 | checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint), |
| 85 | tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), |
| 86 | shard_fn=shard_fn if t5_fsdp else None) |
| 87 | |
| 88 | self.vae_stride = config.vae_stride |
| 89 | self.patch_size = config.patch_size |
| 90 | self.vae = WanVAE( |
| 91 | vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), |
| 92 | device=self.device) |
| 93 | |
| 94 | logging.info(f"Creating VaceWanModel from {checkpoint_dir}") |
| 95 | self.model = VaceWanModel.from_pretrained(checkpoint_dir) |
| 96 | self.model.eval().requires_grad_(False) |
nothing calls this directly
no test coverage detected