(x: torch.Tensor)
| 54 | |
| 55 | |
| 56 | def torch_to_numpy(x: torch.Tensor): |
| 57 | assert isinstance(x, torch.Tensor), \ |
| 58 | f'x must be a torch.Tensor object, but got {type(x)}.' |
| 59 | if x.dtype == torch.bfloat16: |
| 60 | return x.view(torch.int16).detach().cpu().numpy().view(np_bfloat16) |
| 61 | elif x.dtype == torch.float8_e4m3fn: |
| 62 | return x.view(torch.int8).detach().cpu().numpy().view(np_float8) |
| 63 | else: |
| 64 | return x.detach().cpu().numpy() |
| 65 | |
| 66 | |
| 67 | def numpy_to_torch(x): |