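Casts the trainable parameters (those with `requires_grad=True`) of the model, or of each model in a list, to the specified data type. Args: model: The PyTorch model, or list of models, whose trainable parameters will be cast. dtype: The data type to which those parameters will be cast (defaults to `torch.float32`).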
(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32)
| 314 | |
| 315 | |
| 316 | def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32): |
| 317 | """ |
| 318 | Casts the training parameters of the model to the specified data type. |
| 319 | |
| 320 | Args: |
| 321 | model: The PyTorch model whose parameters will be cast. |
| 322 | dtype: The data type to which the model parameters will be cast. |
| 323 | """ |
| 324 | if not isinstance(model, list): |
| 325 | model = [model] |
| 326 | for m in model: |
| 327 | for param in m.parameters(): |
| 328 | # only upcast trainable parameters into fp32 |
| 329 | if param.requires_grad: |
| 330 | param.data = param.to(dtype) |
| 331 | |
| 332 | |
| 333 | def _set_state_dict_into_text_encoder( |
no test coverage detected
searching dependent graphs…