github.com/Wan-Video/Wan2.2

Method WanI2V.__init__

wan/image2video.py:35–126

Initializes the image-to-video generation model components. `config` (EasyDict) holds the model parameters initialized from config.py, and `checkpoint_dir` (str) is the path to the directory containing the model checkpoints.
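The source excerpt on this page reads several attributes directly from `config`. The sketch below uses `types.SimpleNamespace` as a stand-in for EasyDict (both allow attribute access) to show the minimum shape such an object needs for the excerpt. Only the attribute names come from the source; every value, and the helpers `make_placeholder_config`, `missing_fields`, and `t5_checkpoint_path`, are hypothetical. The part of `__init__` not shown here may read further fields.

```python
import os
from types import SimpleNamespace

# Attribute names read from `config` in the excerpt shown on this page.
# The rest of __init__ (not shown) may read additional fields.
REQUIRED_FIELDS = (
    "num_train_timesteps",
    "boundary",
    "param_dtype",
    "text_len",
    "t5_dtype",
    "t5_checkpoint",
)


def make_placeholder_config(**overrides):
    """Hypothetical stand-in for the EasyDict built in config.py.

    Every value is a placeholder, not a real Wan2.2 setting.
    """
    values = {
        "num_train_timesteps": 1000,
        "boundary": 0.5,
        "param_dtype": "bfloat16",
        "text_len": 512,
        "t5_dtype": "bfloat16",
        "t5_checkpoint": "t5_encoder.pth",
    }
    values.update(overrides)
    return SimpleNamespace(**values)


def missing_fields(config):
    """Return the required attribute names that `config` does not define."""
    return [name for name in REQUIRED_FIELDS if not hasattr(config, name)]


def t5_checkpoint_path(checkpoint_dir, config):
    """Resolve the T5 weights path the same way the excerpt does."""
    return os.path.join(checkpoint_dir, config.t5_checkpoint)
```

`missing_fields(make_placeholder_config())` returns an empty list, so that object satisfies every attribute lookup in the excerpt; a partially filled config reports exactly which fields are absent.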

(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
    )
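The docstring states constraints between these placement flags, and the code enforces one of them directly: if `t5_fsdp`, `dit_fsdp`, or `use_sp` is set, `init_on_cpu` is forced to `False`. The sketch below resolves the effective settings under stated assumptions. Only the `init_on_cpu` rule is taken from the code shown. The `t5_cpu` rule ("Only works without t5_fsdp") and the `convert_model_dtype` rule ("Only works without FSDP") are read from the docstring, and tying the latter to `dit_fsdp` is an assumption based on the option concerning DiT parameters. `Placement` and `resolve_placement` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Placement:
    """Hypothetical container for the effective placement settings."""
    init_on_cpu: bool
    t5_on_cpu: bool
    convert_model_dtype: bool


def resolve_placement(t5_fsdp=False, dit_fsdp=False, use_sp=False,
                      t5_cpu=False, init_on_cpu=True, convert_model_dtype=False):
    # From the excerpt: any FSDP sharding or sequence parallelism
    # silently disables CPU initialization.
    distributed = t5_fsdp or dit_fsdp or use_sp
    effective_init_on_cpu = init_on_cpu and not distributed

    # Assumption from the docstring: t5_cpu only takes effect without t5_fsdp.
    effective_t5_cpu = t5_cpu and not t5_fsdp

    # Assumption: "without FSDP" for the DiT dtype conversion means dit_fsdp.
    effective_convert = convert_model_dtype and not dit_fsdp

    return Placement(effective_init_on_cpu, effective_t5_cpu, effective_convert)
```

Like the excerpt, this sketch overrides incompatible flags silently rather than raising an error; a stricter variant could reject such combinations up front.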

Source (wan/image2video.py, lines 33–92 shown):

class WanI2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_sp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of sequence parallel.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            convert_model_dtype (`bool`, *optional*, defaults to False):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.init_on_cpu = init_on_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.boundary = config.boundary
        self.param_dtype = config.param_dtype

        if t5_fsdp or dit_fsdp or use_sp:
            self.init_on_cpu = False

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            # ... (excerpt ends here)
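The line `shard_fn = partial(shard_model, device_id=device_id)` pre-binds the GPU index so later code can shard a model by passing only the model itself. The sketch below illustrates that `functools.partial` binding with a hypothetical `shard_model` stub that merely echoes its arguments. The repository's real `shard_model` presumably performs FSDP wrapping and may take more parameters, and the excerpt ends before showing where `shard_fn` is called.

```python
from functools import partial


def shard_model(model, device_id):
    # Hypothetical stub: the real shard_model presumably wraps `model` with
    # FSDP on `device_id`; this one only echoes its arguments.
    return {"model": model, "device_id": device_id}


device_id = 3  # example GPU index
# Bind device_id once, as the excerpt does.
shard_fn = partial(shard_model, device_id=device_id)

# Downstream code passes only the model.
result = shard_fn("t5-encoder")
print(result)  # {'model': 't5-encoder', 'device_id': 3}
```

Because `device_id` is bound as a keyword argument, a caller can still override it explicitly, for example `shard_fn(model, device_id=5)`, while the default path keeps every sharded component on the same device.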

Callers

Nothing calls this method directly.

Calls (5)

_configure_model (method) · 0.95
T5EncoderModel (class) · 0.85
Wan2_1_VAE (class) · 0.85
get_world_size (function) · 0.85
device (method) · 0.80

Tested by

No test coverage detected.