MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / __init__

Method __init__

wan/image2video.py:34–131  ·  view source on GitHub ↗

Initializes the image-to-video generation model components. Args: config (EasyDict): object containing model parameters initialized from config.py; checkpoint_dir (`str`): path to the directory containing model checkpoints; …

(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    )

Source from the content-addressed store, hash-verified

32class WanI2V:
33
34 def __init__(
35 self,
36 config,
37 checkpoint_dir,
38 device_id=0,
39 rank=0,
40 t5_fsdp=False,
41 dit_fsdp=False,
42 use_usp=False,
43 t5_cpu=False,
44 init_on_cpu=True,
45 ):
46 r"""
47 Initializes the image-to-video generation model components.
48
49 Args:
50 config (EasyDict):
51 Object containing model parameters initialized from config.py
52 checkpoint_dir (`str`):
53 Path to directory containing model checkpoints
54 device_id (`int`, *optional*, defaults to 0):
55 Id of target GPU device
56 rank (`int`, *optional*, defaults to 0):
57 Process rank for distributed training
58 t5_fsdp (`bool`, *optional*, defaults to False):
59 Enable FSDP sharding for T5 model
60 dit_fsdp (`bool`, *optional*, defaults to False):
61 Enable FSDP sharding for DiT model
62 use_usp (`bool`, *optional*, defaults to False):
63 Enable distribution strategy of USP.
64 t5_cpu (`bool`, *optional*, defaults to False):
65 Whether to place T5 model on CPU. Only works without t5_fsdp.
66 init_on_cpu (`bool`, *optional*, defaults to True):
67 Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
68 """
69 self.device = torch.device(f"cuda:{device_id}")
70 self.config = config
71 self.rank = rank
72 self.use_usp = use_usp
73 self.t5_cpu = t5_cpu
74
75 self.num_train_timesteps = config.num_train_timesteps
76 self.param_dtype = config.param_dtype
77
78 shard_fn = partial(shard_model, device_id=device_id)
79 self.text_encoder = T5EncoderModel(
80 text_len=config.text_len,
81 dtype=config.t5_dtype,
82 device=torch.device('cpu'),
83 checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
84 tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
85 shard_fn=shard_fn if t5_fsdp else None,
86 )
87
88 self.vae_stride = config.vae_stride
89 self.patch_size = config.patch_size
90 self.vae = WanVAE(
91 vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),

Callers

nothing calls this directly

Calls 4

T5EncoderModel (class) · 0.85
WanVAE (class) · 0.85
CLIPModel (class) · 0.85
device (method) · 0.80

Tested by

no test coverage detected