(device)
| 1387 | |
| 1388 | # Utils for custom and alternative accelerator devices |
def _is_torch_fp16_available(device):
    """Probe whether float16 arithmetic works on ``device``.

    Allocates a small float16 tensor, moves it to ``device`` and multiplies it
    with itself; if that succeeds, fp16 is considered usable on the device.

    Args:
        device: Anything accepted by ``torch.device`` (e.g. a string such as
            ``"cuda"``, or an existing ``torch.device``). The conversion is done
            outside the probe's ``try``, so an invalid device specification
            propagates whatever ``torch.device`` raises.

    Returns:
        bool: ``False`` if torch is not available or the probe fails on a
        non-CUDA device; ``True`` if the probe succeeds.

    Raises:
        ValueError: If the probe fails on a ``cuda`` device, where fp16 is
            expected to work. The underlying error is chained as the cause.
    """
    if not is_torch_available():
        return False

    import torch

    device = torch.device(device)

    try:
        x = torch.zeros((2, 2), dtype=torch.float16).to(device)
        torch.mul(x, x)
    except Exception as e:
        # Broad catch on purpose: any backend failure means "fp16 not usable here".
        # On CUDA, fp16 should work, so a failure indicates a broken install and is
        # surfaced as an error (chained so the original traceback is preserved).
        if device.type == "cuda":
            raise ValueError(
                f"You have passed a device of type 'cuda' which should work with 'fp16', but 'cuda' does not seem to "
                f"be correctly installed on your machine: {e}"
            ) from e
        return False

    return True
| 1409 | |
| 1410 | |
| 1411 | def _is_torch_fp64_available(device): |
no test coverage detected
searching dependent graphs…