(dtype: str)
| 123 | |
| 124 | |
def get_dtype(dtype: "str | torch.dtype") -> "torch.dtype":
    """Resolve a dtype name (or an already-resolved ``torch.dtype``) to a ``torch.dtype``.

    Accepted names, matched case-insensitively with surrounding whitespace
    ignored: ``"bfloat16"``/``"bf16"``, ``"float16"``/``"fp16"``,
    ``"float32"``/``"fp32"``.

    Args:
        dtype: One of the names above, or one of ``torch.bfloat16``,
            ``torch.float16``, ``torch.float32`` (returned unchanged).

    Returns:
        The corresponding ``torch.dtype``.

    Raises:
        ValueError: If ``dtype`` is not a supported name or supported dtype.
    """
    # Built per call so this function only needs ``torch`` at call time,
    # matching the original's dependency footprint.
    aliases = {
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float16": torch.float16,
        "fp16": torch.float16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }

    # Callers that already hold a supported dtype can pass it straight through;
    # unsupported dtypes are rejected just like unsupported names.
    if isinstance(dtype, torch.dtype):
        if dtype in aliases.values():
            return dtype
        raise ValueError(f"Unsupported dtype: {dtype}")

    if isinstance(dtype, str):
        resolved = aliases.get(dtype.strip().lower())
        if resolved is not None:
            return resolved

    # Non-string, non-dtype inputs (e.g. None) fall through to the same error.
    raise ValueError(f"Unsupported dtype: {dtype}")
| 140 | |
| 141 | |
| 142 | def _has_mps() -> bool: |
no outgoing calls
no test coverage detected
searching dependent graphs…