MCPcopy
hub / github.com/tinygrad/tinygrad / assign

Method assign

tinygrad/tensor.py:256–283  ·  view source on GitHub ↗
(self, x:Tensor|PyConst|list|tuple)

Source from the content-addressed store, hash-verified

254 return self
255
def assign(self, x:Tensor|PyConst|list|tuple) -> Tensor:
  """
  Writes `x` into this tensor in place and returns `self`.

  A non-Tensor `x` (python constant, list or tuple) is first wrapped in a Tensor with `self.dtype`, created on "CPU" when
  `self` lives on a DISK device and on `self.device` otherwise. If the shapes differ, `x` is broadcast to `self.shape`.

  Raises RuntimeError when the shapes still differ after broadcasting, on a device or dtype mismatch (both checks are
  skipped for DISK targets; the device check is also skipped when `x.uop.device` is None), and, for multi-device
  (tuple) targets, when the uop axes differ. Assigning a tensor whose uop is `self.uop` is a no-op.

  DISK targets are written by copying `x._data()` straight into `self._buffer()`. Every other target gets a STORE of
  `x.uop` into `self.uop` wrapped in an AFTER node. If `self.uop` is a view over a BUFFER/AFTER base and has no buffer
  identity of its own, that AFTER is attached to the nearest ancestor (following `src[0]`) that has buffer identity, or
  to the base, and the replacement is applied to tensors via `_apply_map_to_tensors`; otherwise `self.uop` is replaced.
  """
  on_disk = isinstance(self.device, str) and self.device.startswith("DISK")
  if not isinstance(x, Tensor):
    x = Tensor(x, device="CPU" if on_disk else self.device, dtype=self.dtype)
  # self-assignment: nothing to write
  if self.uop is x.uop: return self

  # only the shape is broadcast; the dtype must already match (checked below)
  if self.shape != x.shape:
    x = x._broadcast_to(self.shape)
    if self.shape != x.shape: raise RuntimeError(f"assign shape mismatch {self.shape} != {x.shape}")

  if not on_disk:
    # a uop without a device is not checked against self.device
    if x.uop.device is not None and self.device != x.device:
      raise RuntimeError(f"assign device mismatch {self.device} != {x.device}")
    if self.dtype != x.dtype:
      raise RuntimeError(f"assign dtype mismatch {self.dtype} != {x.dtype}")
  if isinstance(self.device, tuple) and self.uop.axis != x.uop.axis:
    raise RuntimeError(f"multi axis mismatch {self.uop.axis} != {x.uop.axis}")

  if on_disk:
    # TODO: this is a hack for writing to DISK. remove with working assign
    self._buffer().copyin(x._data())
    return self

  # STORE is the (void) write effect; AFTER wraps the view so its shape/ranging stay correct
  stored = self.uop.after(self.uop.store(x.uop))
  base = self.uop.base
  view_assign = base.op in {Ops.BUFFER, Ops.AFTER} and self.uop is not base and not self.uop.has_buffer_identity()
  if not view_assign:
    # simple assign: the tensor's uop becomes the AFTER node
    self.uop = stored
    return self

  # view assign: climb to the first uop with buffer identity (e.g. RESHAPE(BUFFER)), never past the base, and replace
  # there so that @function's substitution catches it
  anchor = self.uop
  while not (anchor.has_buffer_identity() or anchor is base): anchor = anchor.src[0]
  _apply_map_to_tensors({anchor: anchor.after(stored)}, name="Embed View Assign", walk=True)
  return self
284
285 def _buffer(self) -> Buffer:
286 from tinygrad.engine.realize import capturing

Callers 15

__setitem__ · Method · 0.95
__iadd__ · Method · 0.95
__isub__ · Method · 0.95
__imul__ · Method · 0.95
__itruediv__ · Method · 0.95
__ifloordiv__ · Method · 0.95
__ipow__ · Method · 0.95
__iand__ · Method · 0.95
__ior__ · Method · 0.95
__ixor__ · Method · 0.95
__ilshift__ · Method · 0.95
__irshift__ · Method · 0.95

Calls 9

_buffer · Method · 0.95
_data · Method · 0.95
Tensor · Class · 0.85
_apply_map_to_tensors · Function · 0.85
_broadcast_to · Method · 0.80
copyin · Method · 0.80
has_buffer_identity · Method · 0.80
after · Method · 0.45
store · Method · 0.45

Tested by 5

test_implicit_input · Method · 0.76
test_implicit_io · Method · 0.76
copy_weights · Function · 0.36