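Configures a model object. This includes setting evaluation mode, applying the distributed parallel strategy, and handling device placement. Args: model (torch.nn.Module): The model instance to configure. use_sp (`bool`): Enable the sequence-parallel distribution strategy. dit_fsdp (`bool`): Enable FSDP sharding for the DiT model. shard_fn (callable): The function that applies FSDP sharding. convert_model_dtype (`bool`): Convert DiT model parameters to 'config.param_dtype' (only without FSDP). Returns: torch.nn.Module: The configured model.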
(self, model, use_sp, dit_fsdp, shard_fn,
convert_model_dtype)
| 144 | self.audio_sample_m = 0 |
| 145 | |
def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
                     convert_model_dtype):
    """
    Prepare a model for inference and place it according to the parallel setup.

    The model is switched to eval mode with gradients disabled. Sequence
    parallelism is optionally patched in. All ranks are synchronized when a
    process group exists. Finally the model is either handed to ``shard_fn``
    or cast/moved in place.

    Args:
        model (torch.nn.Module):
            The model instance to configure.
        use_sp (`bool`):
            If True, rebind ``self_attn.forward`` of every entry in
            ``model.blocks`` to ``sp_attn_forward_s2v`` and set
            ``model.use_context_parallel = True``.
        dit_fsdp (`bool`):
            If True, the model is passed to ``shard_fn`` and its result is
            returned; no dtype conversion or device move is done here.
        shard_fn (callable):
            Called as ``shard_fn(model)`` to apply FSDP sharding.
        convert_model_dtype (`bool`):
            Cast parameters to ``self.param_dtype``. Only honored when
            ``dit_fsdp`` is False.

    Returns:
        torch.nn.Module:
            The configured model (the value returned by ``shard_fn`` when
            ``dit_fsdp`` is True, otherwise ``model`` itself).
    """
    model.eval().requires_grad_(False)

    if use_sp:
        for blk in model.blocks:
            attn = blk.self_attn
            # Bind the sequence-parallel forward as a method of this attention
            # instance so it receives the instance as its first argument.
            attn.forward = types.MethodType(sp_attn_forward_s2v, attn)
        model.use_context_parallel = True

    # Synchronize all ranks (only if a process group has been initialized)
    # before sharding or device placement.
    if dist.is_initialized():
        dist.barrier()

    if dit_fsdp:
        return shard_fn(model)

    # Non-FSDP path: optional dtype cast, then device move unless the model
    # is meant to stay on CPU.
    if convert_model_dtype:
        model.to(self.param_dtype)
    if not self.init_on_cpu:
        model.to(self.device)
    return model
| 188 | |
| 189 | def get_size_less_than_area(self, |
| 190 | height, |