(self, w:Tensor, dtype:DTypeLike|None=None)
| 1431 | # *** image Tensor function replacements *** |
| 1432 | |
def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
  """Matrix product `self @ w`, computed with a 1x1 `image_conv2d`.

  The matmul mxk @ kxn is rewritten as (1,k,m,1).conv2d(n,k,1,1). Leading dims of `w` (everything
  before its last two) are multiplied into the conv group count; leading dims of `self` are
  multiplied into `bs`, which is split as bs//groups conv batches.

  Args:
    w: right operand, at least 1D. Its contraction dim (w.shape[-2], or w.shape[-1] when 1D) must
      equal self.shape[-1]. A 1D `w` of shape (k,) is treated as a (k,1) column, and the trailing
      size-1 output dim is dropped, so the result has shape self.shape[:-1].
    dtype: forwarded unchanged to `image_conv2d`.

  Returns:
    The product tensor.

  Raises:
    RuntimeError: if either operand is 0D, or if the inner (contraction) dims do not match.
  """
  # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1)
  if not (self.ndim > 0 and w.ndim > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {self.ndim=}, {w.ndim=}")
  if self.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {self.shape} and {w.shape}")

  # A 1D w passes the checks above but would fail on w.shape[-2] below. Promote it to a (k,1)
  # column and drop the resulting size-1 output dim, giving shape self.shape[:-1] like `x @ w`.
  if w.ndim == 1: return self.image_dot(w.reshape(w.shape[0], 1), dtype=dtype).reshape(self.shape[:-1])

  bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1]
  # output shape before the final transpose: (*batch, cout, rows); a 1D self yields just (cout,)
  out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout,)

  # NOTE: with NHWC we can remove the transposes
  # input laid out as (bs//groups) x (groups*cin) x H x 1; H (= -1) is inferred from the remaining size.
  # NOTE(review): assumes bs is divisible by groups -- confirm with callers.
  cx = self.transpose(self.ndim-1, self.ndim-2).reshape(bs//groups, groups*cin, -1, 1)
  # weight laid out as (groups*cout) x cin x 1 x 1, i.e. a 1x1 kernel
  cw = w.transpose(w.ndim-1, w.ndim-2).reshape(groups*cout, cin, 1, 1)
  return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
| 1447 | |
| 1448 | def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: |
| 1449 | dtsz = 2 if FLOAT16 else 4 |
no test coverage detected