Cast the model and its floating-point tensor inputs to bfloat16 when the backend requires it.
(backend, model, inputs_dict)
| 11 | |
| 12 | |
| 13 | def _maybe_cast_to_bf16(backend, model, inputs_dict): |
| 14 | """Cast model and floating-point inputs to bfloat16 when the backend requires it.""" |
| 15 | if not backend or backend not in _BF16_REQUIRED_BACKENDS: |
| 16 | return model, inputs_dict |
| 17 | model = model.to(dtype=torch.bfloat16) |
| 18 | inputs_dict = { |
| 19 | k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v |
| 20 | for k, v in inputs_dict.items() |
| 21 | } |
| 22 | return model, inputs_dict |
no test coverage detected
searching dependent graphs…