| 254 | return self |
| 255 | |
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
  """
  Assign `x` into this tensor in place and return `self`.

  A non-Tensor `x` is wrapped in a Tensor with `self.dtype`. It is placed on "CPU" when `self` lives on a DISK
  device, and on `self.device` otherwise. When shapes differ, `x` is broadcast to `self.shape`.

  Raises RuntimeError if the shapes still differ after broadcasting. For non-DISK targets it also raises on a
  device mismatch (only when `x` carries a device), a dtype mismatch, or a multi-device axis mismatch.
  """
  on_disk = isinstance(self.device, str) and self.device.startswith("DISK")
  x = x if isinstance(x, Tensor) else Tensor(x, device="CPU" if on_disk else self.device, dtype=self.dtype)
  # assigning a tensor's own graph back into it changes nothing
  if self.uop is x.uop: return self
  # only the shape is broadcast here; dtype is never converted
  if self.shape != x.shape: x = x._broadcast_to(self.shape)
  if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")

  # TODO: DISK writes bypass the graph and copy x's data straight into the buffer; remove once assign works for DISK.
  # The device/dtype/axis checks below do not apply to DISK targets.
  if on_disk:
    self._buffer().copyin(x._data())
    return self

  if x.uop.device is not None and self.device != x.device:
    raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
  if self.dtype != x.dtype: raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
  if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis: raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")

  # STORE is the (void) write effect; AFTER wraps the view so its shape/ranging stay correct
  stored = self.uop.after(self.uop.store(x.uop))
  base = self.uop.base
  view_assign = base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity()
  if not view_assign:
    # simple assign: this tensor now points directly at the AFTER node
    self.uop = stored
    return self

  # view assign: walk down to the node that has buffer identity (e.g. RESHAPE(BUFFER)), or stop at the base,
  # and substitute there so that @function's substitution catches it
  node = self.uop
  while not (node.has_buffer_identity() or node is base): node = node.src[0]
  _apply_map_to_tensors({node: node.after(stored)}, name="Embed View Assign", walk=True)
  return self
| 284 | |
| 285 | def _buffer(self) -> Buffer: |
| 286 | from tinygrad.engine.realize import capturing |