hub / github.com/Wan-Video/Wan2.2 / _configure_model

Method _configure_model

wan/speech2video.py:146–187

Configures a model object. This includes setting evaluation mode, applying the distributed parallel strategy (sequence parallel and, optionally, FSDP sharding), and handling dtype conversion and device placement. Returns the configured model.

(self, model, use_sp, dit_fsdp, shard_fn, convert_model_dtype)
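As a rough illustration of the non-distributed path (no sequence parallel, no FSDP), the sketch below mirrors the freeze, dtype-conversion, and device-placement steps on a toy module. `configure_for_inference` and all of its parameters are hypothetical stand-ins and are not part of Wan2.2; only standard PyTorch calls are used.

```python
import torch
import torch.nn as nn


def configure_for_inference(model, param_dtype=torch.bfloat16,
                            device="cpu", convert_model_dtype=True,
                            init_on_cpu=False):
    """Hypothetical stand-in for the non-FSDP branch of _configure_model."""
    # Switch to eval mode and disable gradients on every parameter.
    model.eval().requires_grad_(False)
    if convert_model_dtype:
        # nn.Module.to() casts parameters in place and returns the module,
        # so the result does not need to be reassigned.
        model.to(param_dtype)
    if not init_on_cpu:
        model.to(device)
    return model


# Toy model used only for illustration.
toy = nn.Sequential(nn.Linear(4, 8), nn.GELU(), nn.Linear(8, 2))
toy = configure_for_inference(toy)
```

After the call, `toy.training` is `False` and every parameter has `requires_grad=False` and dtype `torch.bfloat16`.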

Source from the content-addressed store, hash-verified

144          self.audio_sample_m = 0
145
146      def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
147                           convert_model_dtype):
148          """
149          Configures a model object. This includes setting evaluation modes,
150          applying distributed parallel strategy, and handling device placement.
151
152          Args:
153              model (torch.nn.Module):
154                  The model instance to configure.
155              use_sp (`bool`):
156                  Enable distribution strategy of sequence parallel.
157              dit_fsdp (`bool`):
158                  Enable FSDP sharding for DiT model.
159              shard_fn (callable):
160                  The function to apply FSDP sharding.
161              convert_model_dtype (`bool`):
162                  Convert DiT model parameters dtype to 'config.param_dtype'.
163                  Only works without FSDP.
164
165          Returns:
166              torch.nn.Module:
167                  The configured model.
168          """
169          model.eval().requires_grad_(False)
170          if use_sp:
171              for block in model.blocks:
172                  block.self_attn.forward = types.MethodType(
173                      sp_attn_forward_s2v, block.self_attn)
174              model.use_context_parallel = True
175
176          if dist.is_initialized():
177              dist.barrier()
178
179          if dit_fsdp:
180              model = shard_fn(model)
181          else:
182              if convert_model_dtype:
183                  model.to(self.param_dtype)
184              if not self.init_on_cpu:
185                  model.to(self.device)
186
187          return model
188
189      def get_size_less_than_area(self,
190                                  height,
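
The `use_sp` branch replaces each block's attention `forward` by binding a replacement function to the instance with `types.MethodType`. The pure-Python sketch below shows that mechanism in isolation. `Attention`, `Block`, and `patched_forward` are made-up names for illustration and do not reproduce the behavior of `sp_attn_forward_s2v`.

```python
import types


class Attention:
    def __init__(self, scale):
        self.scale = scale

    def forward(self, x):
        return x * self.scale


def patched_forward(self, x):
    # Bound via MethodType, so `self` is the specific Attention instance.
    return ("patched", x * self.scale)


class Block:
    def __init__(self, scale):
        self.self_attn = Attention(scale)


blocks = [Block(2), Block(3)]
untouched = Attention(5)

for block in blocks:
    # Instance-level override: the class attribute Attention.forward
    # is left unchanged, so other instances are unaffected.
    block.self_attn.forward = types.MethodType(patched_forward,
                                               block.self_attn)

results = [b.self_attn.forward(10) for b in blocks]
```

Because the override lives on the instance rather than the class, only the patched objects change behavior; `untouched.forward(10)` still goes through the original method.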

Callers 1

__init__ · Method · 0.95

Calls 1

to · Method · 0.80

Tested by

no test coverage detected