| 13 | |
# Backends probed, in this order, when the default device is auto-selected (see _Device.get_available_devices).
ALL_DEVICES = [
  "METAL", "AMD", "NV", "CUDA", "QCOM",
  "CL", "CPU", "DSP", "WEBGPU",
]
class _Device:
  """Singleton registry of devices, exposed as the module-level `Device`.

  Maps user-facing device strings to canonical names, lazily imports and instantiates the matching
  `runtime/ops_<backend>` device class, and chooses the default device.
  """
  def __init__(self) -> None:
    # backend names derived from the runtime/ops_*.py files next to this module, e.g. ops_cuda.py -> "CUDA"
    self._devices = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
    # canonical names of every device successfully opened so far (filled by __get_canonicalized_item)
    self._opened_devices:set[str] = set()

  @functools.cache  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def _canonicalize(self, device:str) -> str:
    # upper-case the backend prefix (text before the first ":"), keep the suffix verbatim, and strip a
    # trailing ":0" so that e.g. "cuda:0" and "CUDA" map to the same name
    return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])

  # NOTE: you can't cache canonicalize in case Device.DEFAULT changes
  def canonicalize(self, device:str|None) -> str:
    """Return the canonical device name for `device`; None means the current `Device.DEFAULT`."""
    return self._canonicalize(device if device is not None else Device.DEFAULT)

  def __getitem__(self, ix:str) -> Compiled:
    """Return the device object for `ix`, opening it on first access.

    Raises AssertionError (when assertions are enabled) if ALLOW_DEVICE_USAGE is falsy and the
    backend of `ix` is not one of DISK, TINYFS, NPY or PYTHON.
    """
    ix = self.canonicalize(ix)
    assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
    return self.__get_canonicalized_item(ix)

  @functools.cache  # this class is a singleton, pylint: disable=method-cache-max-size-none
  def __get_canonicalized_item(self, ix:str) -> Compiled:
    # successful results are cached per canonical name, so a given device is normally constructed once
    base = (__package__ or __name__).split('.')[0] # tinygrad
    x = ix.split(":")[0].lower()
    # import <base>.runtime.ops_<backend> and instantiate the first member named "<backend>device"
    # (case-insensitive), passing the full canonical name such as "CUDA:1"
    ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
           if (cname.lower() == x + "device")][0](ix)
    if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
    self._opened_devices.add(ix)
    return ret

  @property
  def default(self) -> Compiled:
    """Device object for the current `DEFAULT` device name."""
    return self[self.DEFAULT]

  def get_available_devices(self) -> Iterator[str]:
    """Yield the name of each device in ALL_DEVICES that opens successfully, in ALL_DEVICES order."""
    for device in ALL_DEVICES:
      # Only failures while opening the device are ignored. The yield sits outside the try so that
      # exceptions thrown into this generator by the consumer propagate instead of being swallowed.
      try: name = self[device].device
      except Exception: continue
      yield name

  @property
  def DEFAULT(self) -> str:
    """Default device name: `DEV.device` when it is truthy, otherwise the auto-selected device."""
    return DEV.device or self._select_device
  @DEFAULT.setter
  def DEFAULT(self, v): raise AttributeError(f'setting Device.DEFAULT is deprecated, use "with Context(DEV={v!r})" or "DEV.value = {v!r}"')

  @functools.cached_property
  def _select_device(self) -> str:
    """Pick the first device in ALL_DEVICES that opens; computed once per instance.

    Also exports the choice as the DEV environment variable so spawned child processes inherit it.
    Raises AssertionError if a deprecated `<BACKEND>=1` environment variable is set (DISK, TINYFS and
    NPY are exempt), and RuntimeError if no device in ALL_DEVICES can be opened.
    """
    assert (dev:=next((d for d in self._devices if d not in ["DISK", "TINYFS", "NPY"] and getenv(d) == 1), None)) is None, \
      f"{dev}=1 is deprecated, use DEV={dev} instead"
    try:
      device = next(self.get_available_devices())
      os.environ["DEV"] = device # we set this in environment for spawned children
      return device
    except StopIteration as exc: raise RuntimeError("no usable devices") from exc
Device: _Device = _Device()

def _finalize_opened_devices() -> None:
  """At interpreter exit, call finalize() on every device that was opened through `Device`."""
  # Iterate over a snapshot. If a finalize() call opened another device, it would be added to the set,
  # and iterating the live set would then raise "Set changed size during iteration".
  for dn in list(Device._opened_devices): Device[dn].finalize()

atexit.register(_finalize_opened_devices)
| 56 |
no outgoing calls
no test coverage detected
searching dependent graphs…