MCPcopy
hub / github.com/tinygrad/tinygrad / image_dot

Method image_dot

tinygrad/tensor.py:1433–1446  ·  view source on GitHub ↗
(self, w:Tensor, dtype:DTypeLike|None=None)

Source from the content-addressed store, hash-verified

1431 # *** image Tensor function replacements ***
1432
def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
  """
  Multiply `self` by `w` (like `dot`) using a 1x1 `image_conv2d` instead of a reduce.

  An (m x k) @ (k x n) matmul is expressed as a conv2d of a (1, k, m, 1) input with an
  (n, k, 1, 1) kernel. Leading dims of `w` are flattened into conv `groups`; leading dims
  of `self` are flattened into `bs`, and the conv batch is `bs // groups`.

  Args:
    w: right operand, at least 1D. Its second-to-last dim (or its only dim when 1D)
      must equal `self.shape[-1]`.
    dtype: forwarded unchanged to `image_conv2d`.

  Returns:
    The product. When `w` is 1D the contracted dim is removed, so the result has shape
    `self.shape[:-1]` (a scalar shape when `self` is also 1D).

  Raises:
    RuntimeError: if either operand is 0D, or the contracted dimensions differ.
  """
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
  if not (self.ndim > 0 and w.ndim > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {self.ndim=}, {w.ndim=}")
  if self.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {self.shape} and {w.shape}")

  # A 1D w passes the checks above but has no w.shape[-2] for the code below.
  # Treat it as a (k, 1) column, then drop the resulting size-1 output dim.
  if w.ndim == 1: return self.image_dot(w.reshape(w.shape[0], 1), dtype=dtype).reshape(self.shape[:-1])

  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
  # NOTE(review): when bs is not a multiple of groups (e.g. batched w with unbatched self),
  # bs//groups truncates and the reshape below cannot match. This looks unsupported; confirm.
  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout,)

  # NOTE: with NHWC we can remove the transposes
  # input: (bs//groups) x (groups*cin) x m x 1
  cx = self.transpose(self.ndim-1, self.ndim-2).reshape(bs//groups, groups*cin, -1, 1)
  # kernel: (groups*cout) x cin x 1 x 1
  cw = w.transpose(w.ndim-1, w.ndim-2).reshape(groups*cout, cin, 1, 1)
  # conv output is reshaped to (..., cout, m), then the last two dims are swapped to get (..., m, cout)
  return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
1447
1448 def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
1449 dtsz = 2 if FLOAT16 else 4

Callers 1

dot (method) · 0.95

Calls 4

prod (function) · 0.90
reshape (method) · 0.80
image_conv2d (method) · 0.80
transpose (method) · 0.45

Tested by

no test coverage detected