()
| 510 | |
| 511 | # bf16 > fp16 > fp32 |
def preferred_dtype():
    """Return the highest-priority torch dtype the current accelerator reports.

    Capability checks run in priority order (bf16, then fp16) against the
    object returned by ``get_accelerator()``; the first check that passes
    decides the result. ``torch.float32`` is the fallback when neither
    reduced-precision check passes.
    """
    # Ordered from most to least preferred; each entry pairs the name of the
    # accelerator capability probe with the dtype it unlocks.
    ranked_probes = (
        ("is_bf16_supported", torch.bfloat16),
        ("is_fp16_supported", torch.float16),
    )
    for probe_name, dtype in ranked_probes:
        # The accelerator is looked up afresh for every probe, matching the
        # one-lookup-per-check pattern of a plain if/elif chain.
        if getattr(get_accelerator(), probe_name)():
            return dtype
    return torch.float32
| 519 | |
| 520 | |
| 521 | class EnableDeterminism: |
searching dependent graphs…