Repository: github.com/huggingface/diffusers

Function cast_training_params

src/diffusers/training_utils.py:316–330

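Casts the trainable parameters (those with requires_grad=True) of a model, or of each model in a list, to the specified data type. Frozen parameters are left unchanged.

Args:
  model: The PyTorch model, or list of models, whose trainable parameters will be cast.
  dtype: The data type to cast those parameters to (default: torch.float32).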

(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32)
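A minimal sketch of the behavior described above, assuming a setup where a model is held in float16 and only part of it is trainable. The helper `upcast_trainable_params` and the two-layer `nn.Sequential` model are hypothetical: the helper repeats the same loop as `cast_training_params` so the snippet runs with only PyTorch installed.

```python
import torch
from torch import nn


def upcast_trainable_params(model, dtype=torch.float32):
    # Hypothetical stand-in repeating the loop in cast_training_params,
    # so this snippet runs without diffusers installed.
    models = model if isinstance(model, list) else [model]
    for m in models:
        for param in m.parameters():
            if param.requires_grad:  # frozen parameters are skipped
                param.data = param.to(dtype)


# Toy model: a frozen "base" layer followed by a small trainable layer.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
model[0].requires_grad_(False)

# Hold every parameter in half precision to save memory.
model.to(torch.float16)

# Upcast only the trainable parameters back to float32.
upcast_trainable_params(model, dtype=torch.float32)

for name, p in model.named_parameters():
    print(name, p.dtype, p.requires_grad)
# 0.weight torch.float16 False
# 0.bias torch.float16 False
# 1.weight torch.float32 True
# 1.bias torch.float32 True
```

With diffusers installed, the helper can be replaced by `from diffusers.training_utils import cast_training_params` and a call to `cast_training_params(model, dtype=torch.float32)`. Keeping frozen weights in half precision while upcasting only the trainable ones is a common mixed-precision pattern: memory stays low, while the weights the optimizer updates stay in full precision.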

Source

```python
def cast_training_params(model: torch.nn.Module | list[torch.nn.Module], dtype=torch.float32):
    """
    Casts the training parameters of the model to the specified data type.

    Args:
        model: The PyTorch model whose parameters will be cast.
        dtype: The data type to which the model parameters will be cast.
    """
    if not isinstance(model, list):
        model = [model]
    for m in model:
        for param in m.parameters():
            # only upcast trainable parameters into fp32
            if param.requires_grad:
                param.data = param.to(dtype)
```
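The source casts with `param.data = param.to(dtype)` instead of rebinding a new `nn.Parameter`. One plausible benefit (our reading; the source does not state a reason) is that the module keeps the same Parameter objects, so anything already holding references to them, such as an optimizer built earlier, sees the new dtype. A minimal sketch, assuming a single `nn.Linear` layer and `torch.optim.SGD` purely for illustration:

```python
import torch
from torch import nn

layer = nn.Linear(4, 4).to(torch.float16)

# An optimizer created *before* casting holds references to the Parameter objects.
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

for p in layer.parameters():
    # Same assignment as the listed source: replace the tensor behind the
    # Parameter, not the Parameter object itself.
    p.data = p.to(torch.float32)

tracked = opt.param_groups[0]["params"][0]
print(tracked is layer.weight, tracked.dtype)  # True torch.float32
```

Replacing the attribute instead (for example `layer.weight = nn.Parameter(...)`) would leave the optimizer pointing at the old float16 tensor.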

Callers 15

load_model_hook (function, confidence 0.90): listed 6 times
main (function, confidence 0.90): listed 6 times
(12 of the 15 reported callers appear in the listing.)

Calls 2

parameters (method, confidence 0.80)
to (method, confidence 0.45)

Tested by

no test coverage detected
