Repository: github.com/huggingface/diffusers

Function _maybe_cast_to_bf16

tests/models/testing_utils/utils.py, lines 13–22

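Cast the model and any floating-point tensor inputs to bfloat16 when `backend` is listed in `_BF16_REQUIRED_BACKENDS`. Integer tensors and non-tensor values pass through unchanged, and if `backend` is empty or not listed, the model and inputs are returned as-is.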

(backend, model, inputs_dict)


```python
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Cast model and floating-point inputs to bfloat16 when the backend requires it."""
    if not backend or backend not in _BF16_REQUIRED_BACKENDS:
        return model, inputs_dict
    model = model.to(dtype=torch.bfloat16)
    inputs_dict = {
        k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
        for k, v in inputs_dict.items()
    }
    return model, inputs_dict
```

Callers 1

_context_parallel_worker (function) · 0.85

Calls 1

to (method) · 0.45

Tested by

no test coverage detected
