Shards the tensor across the given devices. Optionally, specify which axis to shard on.

```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).uop)
```
(self, devices:tuple[str, ...], axis:int|None=None)
| 381 | return self.replace(real) |
| 382 | |
| 383 | def shard(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: |
| 384 | """ |
| 385 | Shards the tensor across the given devices. Optionally specify which axis to shard on. |
| 386 | |
| 387 | ```python exec="true" source="above" session="tensor" result="python" |
| 388 | t = Tensor.empty(2, 4) |
| 389 | print(t.shard((t.device, t.device), axis=1).uop) |
| 390 | ``` |
| 391 | """ |
| 392 | if self.uop.device is None: return self |
| 393 | if not isinstance(self.device, str): raise RuntimeError("can't shard a multi-device tensor") |
| 394 | if len(devices) == 1: return self.to(devices[0]) |
| 395 | devices = cast(tuple[str, ...], canonicalize_device(devices)) |
| 396 | uop = self.uop.shard(devices, self._resolve_dim(axis)) if axis is not None else self.uop.copy_to_device(devices) |
| 397 | return Tensor(uop).is_param_(self.is_param) |
| 398 | |
| 399 | def shard_(self, devices:tuple[str, ...], axis:int|None=None) -> Tensor: |
| 400 | """ |