Resolve device from explicit request > env var > auto-detect. Args: requested: Device string from CLI arg. None or "auto" triggers env var lookup then auto-detection. Returns: Validated device string ("cuda", "mps", or "cpu"). Raises: RuntimeError: If the device name is unknown or the requested backend is unavailable.
(requested: str | None = None)
| 69 | |
| 70 | |
def resolve_device(requested: str | None = None) -> str:
    """Resolve device from explicit request > env var > auto-detect.

    Args:
        requested: Device string from CLI arg. None, empty, or "auto"
            (case-insensitive, surrounding whitespace ignored) triggers
            env var lookup then auto-detection.

    Returns:
        Validated device string ("cuda", "mps", or "cpu").

    Raises:
        RuntimeError: If the device name is unknown or the requested
            backend is unavailable.
    """
    import torch

    # Normalize once, *before* the "auto" check. Previously the value was
    # lowercased only afterwards, so "CUDA" was accepted while "AUTO"/"Auto"
    # was rejected as an unknown device.
    device = (requested or "").strip().lower()

    # CLI arg takes priority, then env var, then auto. An unset *or empty*
    # env var falls through to auto-detection rather than erroring.
    if device in ("", "auto"):
        device = os.environ.get(DEVICE_ENV_VAR, "").strip().lower() or "auto"

    if device == "auto":
        return detect_best_device()

    if device not in VALID_DEVICES:
        raise RuntimeError(f"Unknown device '{device}'. Valid options: {', '.join(VALID_DEVICES)}")

    # Validate the explicit request against what this torch build/machine offers.
    if device == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA requested but torch.cuda.is_available() is False. Install a CUDA-enabled PyTorch build."
            )
    elif device == "mps":
        # Older torch builds have no torch.backends.mps attribute at all.
        if not hasattr(torch.backends, "mps"):
            raise RuntimeError(
                "MPS requested but this PyTorch build has no MPS support. Install PyTorch >= 1.12 with MPS backend."
            )
        if not torch.backends.mps.is_available():
            raise RuntimeError(
                "MPS requested but not available on this machine. Requires Apple Silicon (M1+) with macOS 12.3+."
            )

    return device
| 115 | |
| 116 | |
| 117 | @dataclass |