(x)
| 65 | |
| 66 | |
def numpy_to_torch(x):
    """Convert a NumPy array to a torch tensor, covering bfloat16/float8 input.

    An array whose dtype equals ``np_bfloat16`` is reinterpreted as
    ``np.int16``, passed to ``torch.from_numpy``, and the resulting tensor is
    re-viewed as ``torch.bfloat16``. An array whose dtype equals ``np_float8``
    takes the same route through ``np.int8`` and ``torch.float8_e4m3fn``.
    Any other array is passed to ``torch.from_numpy`` unchanged.

    Args:
        x: NumPy array to convert.

    Returns:
        The tensor built by ``torch.from_numpy``, re-viewed to the matching
        torch dtype for the two special-cased input dtypes.
    """
    # NOTE(review): the integer detour is presumably required because
    # torch.from_numpy does not accept these extension dtypes -- confirm.
    src_dtype = x.dtype

    if src_dtype == np_bfloat16:
        raw_bits = x.view(np.int16)
        return torch.from_numpy(raw_bits).view(torch.bfloat16)

    if src_dtype == np_float8:
        raw_bits = x.view(np.int8)
        return torch.from_numpy(raw_bits).view(torch.float8_e4m3fn)

    return torch.from_numpy(x)
| 74 | |
| 75 | |
| 76 | def numpy_to_dtype(x, dtype: str): |
no test coverage detected