r""" Initializes the image-to-video generation model components. Args: config (EasyDict): Object containing model parameters initialized from config.py; checkpoint_dir (`str`): Path to the directory containing model checkpoints.
(
self,
config,
checkpoint_dir,
device_id=0,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=True,
convert_model_dtype=False,
pipe_dtype=torch.bfloat16,
local_attn_size=-1,
sink_size=0,
)
| 37 | class WanI2VFast: |
| 38 | |
| 39 | def __init__( |
| 40 | self, |
| 41 | config, |
| 42 | checkpoint_dir, |
| 43 | device_id=0, |
| 44 | rank=0, |
| 45 | t5_fsdp=False, |
| 46 | dit_fsdp=False, |
| 47 | use_sp=False, |
| 48 | t5_cpu=False, |
| 49 | init_on_cpu=True, |
| 50 | convert_model_dtype=False, |
| 51 | pipe_dtype=torch.bfloat16, |
| 52 | local_attn_size=-1, |
| 53 | sink_size=0, |
| 54 | ): |
| 55 | r""" |
| 56 | Initializes the image-to-video generation model components. |
| 57 | |
| 58 | Args: |
| 59 | config (EasyDict): |
| 60 | Object containing model parameters initialized from config.py |
| 61 | checkpoint_dir (`str`): |
| 62 | Path to directory containing model checkpoints |
| 63 | device_id (`int`, *optional*, defaults to 0): |
| 64 | Id of target GPU device |
| 65 | rank (`int`, *optional*, defaults to 0): |
| 66 | Process rank for distributed training |
| 67 | t5_fsdp (`bool`, *optional*, defaults to False): |
| 68 | Enable FSDP sharding for T5 model |
| 69 | dit_fsdp (`bool`, *optional*, defaults to False): |
| 70 | Enable FSDP sharding for DiT model |
| 71 | use_sp (`bool`, *optional*, defaults to False): |
| 72 | Enable distribution strategy of sequence parallel. |
| 73 | t5_cpu (`bool`, *optional*, defaults to False): |
| 74 | Whether to place T5 model on CPU. Only works without t5_fsdp. |
| 75 | init_on_cpu (`bool`, *optional*, defaults to True): |
| 76 | Enable initializing Transformer Model on CPU. Only works without FSDP or USP. |
| 77 | convert_model_dtype (`bool`, *optional*, defaults to False): |
| 78 | Convert DiT model parameters dtype to 'config.param_dtype'. |
| 79 | Only works without FSDP. |
| 80 | """ |
| 81 | self.device = torch.device(f"cuda:{device_id}") |
| 82 | self.config = config |
| 83 | self.rank = rank |
| 84 | self.t5_cpu = t5_cpu |
| 85 | self.init_on_cpu = init_on_cpu |
| 86 | |
| 87 | self.num_train_timesteps = config.num_train_timesteps |
| 88 | self.boundary = config.boundary |
| 89 | self.param_dtype = config.param_dtype |
| 90 | self.pipe_dtype = pipe_dtype |
| 91 | self.local_attn_size = local_attn_size |
| 92 | self.sink_size = sink_size |
| 93 | |
| 94 | if t5_fsdp or dit_fsdp or use_sp: |
| 95 | self.init_on_cpu = False |
| 96 |
Nothing calls this directly.
No test coverage detected.