MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / __init__

Method __init__

wan/vace.py:39–137  ·  view source on GitHub ↗

Initializes the Wan text-to-video generation model components.

Args:
    config (EasyDict): Object containing model parameters initialized from config.py.
    checkpoint_dir (`str`): Path to directory containing model checkpoints.

(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
    )

Source from the content-addressed store, hash-verified

37class WanVace(WanT2V):
38
39 def __init__(
40 self,
41 config,
42 checkpoint_dir,
43 device_id=0,
44 rank=0,
45 t5_fsdp=False,
46 dit_fsdp=False,
47 use_usp=False,
48 t5_cpu=False,
49 ):
50 r"""
51 Initializes the Wan text-to-video generation model components.
52
53 Args:
54 config (EasyDict):
55 Object containing model parameters initialized from config.py
56 checkpoint_dir (`str`):
57 Path to directory containing model checkpoints
58 device_id (`int`, *optional*, defaults to 0):
59 Id of target GPU device
60 rank (`int`, *optional*, defaults to 0):
61 Process rank for distributed training
62 t5_fsdp (`bool`, *optional*, defaults to False):
63 Enable FSDP sharding for T5 model
64 dit_fsdp (`bool`, *optional*, defaults to False):
65 Enable FSDP sharding for DiT model
66 use_usp (`bool`, *optional*, defaults to False):
67 Enable distribution strategy of USP.
68 t5_cpu (`bool`, *optional*, defaults to False):
69 Whether to place T5 model on CPU. Only works without t5_fsdp.
70 """
71 self.device = torch.device(f"cuda:{device_id}")
72 self.config = config
73 self.rank = rank
74 self.t5_cpu = t5_cpu
75
76 self.num_train_timesteps = config.num_train_timesteps
77 self.param_dtype = config.param_dtype
78
79 shard_fn = partial(shard_model, device_id=device_id)
80 self.text_encoder = T5EncoderModel(
81 text_len=config.text_len,
82 dtype=config.t5_dtype,
83 device=torch.device('cpu'),
84 checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
85 tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
86 shard_fn=shard_fn if t5_fsdp else None)
87
88 self.vae_stride = config.vae_stride
89 self.patch_size = config.patch_size
90 self.vae = WanVAE(
91 vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
92 device=self.device)
93
94 logging.info(f"Creating VaceWanModel from {checkpoint_dir}")
95 self.model = VaceWanModel.from_pretrained(checkpoint_dir)
96 self.model.eval().requires_grad_(False)

Callers

nothing calls this directly

Calls 4

T5EncoderModel (Class) · 0.85
WanVAE (Class) · 0.85
VaceVideoProcessor (Class) · 0.85
device (Method) · 0.80

Tested by

no test coverage detected