Shards the tensor the same way as `y` (same devices and axis).
(self, y:Tensor)
| 403 | return self.replace(self.shard(devices, axis)) |
| 404 | |
def shard_like(self, y:Tensor) -> Tensor:
  """
  Place this tensor the same way `y` is placed.

  Three cases, decided by `y.device`:
  - `None`: this tensor is returned unchanged.
  - a single device string: this tensor is moved there via `self.to`.
  - anything else: this tensor is sharded over `y.device` along `y.uop.axis`,
    unless `self.device` is already a tuple equal to `y.device` with the same
    `uop.axis`, in which case this tensor is returned as is.
  """
  target = y.device
  if target is None: return self
  if isinstance(target, str): return self.to(target)
  # only consult self.uop.axis once self.device is known to be a tuple
  already_matching = isinstance(self.device, tuple) and (target, y.uop.axis) == (self.device, self.uop.axis)
  if already_matching: return self
  return self.shard(target, y.uop.axis)
| 412 | |
| 413 | CHUNK_SIZE = 2**20 |
| 414 | def fs_load(self, size:int) -> Tensor: |