MCPcopy
hub / github.com/tinygrad/tinygrad / image_conv2d

Method image_conv2d

tinygrad/tensor.py:1448–1532  ·  view source on GitHub ↗
(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None)

Source from the content-addressed store, hash-verified

1446 return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2)
1447
1448 def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
1449 dtsz = 2 if FLOAT16 else 4
1450
1451 (bs,_,_,_), (cout,cin,H,W) = self.shape, weight.shape
1452 assert isinstance(cin, int) and isinstance(cout, int)
1453 x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)
1454
1455 padding_neg, padding_pos = [min(0, p) for p in resolve_pool_pads(padding, 2)], [max(0, p) for p in resolve_pool_pads(padding, 2)]
1456 x = x.pad(padding_neg)
1457 iy, ix = x.shape[2:]
1458
1459 # hack for non multiples of 4 on cin
1460 if cin % 4 != 0 and not (cin == 1 and groups%4 == 0):
1461 new_cin = round_up(cin, 4)
1462 w = w.pad_to(None, None, new_cin, None, None)
1463 x = x.reshape(bs, groups, cin, iy, ix)
1464 x = x.pad_to(None, None, new_cin, None, None).reshape(bs, groups*new_cin, iy, ix)
1465 cin = new_cin
1466
1467 # hack for non multiples of 4 on rcout
1468 added_output_channels = 0
1469 if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0):
1470 added_output_channels = 4 - (rcout % 4)
1471 rcout += added_output_channels
1472 cout = groups * rcout
1473 w = w.pad_to(None, rcout, None, None, None)
1474
1475 # packed (note: flipping bs and iy would make the auto-padding work)
1476 x = x.permute(0,2,3,1)
1477 cin_last = iy == 1 and ix == 1
1478 if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1)
1479 elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3)
1480 else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1)
1481
1482 def is_pow2(v): return v > 0 and v & (v - 1) == 0
1483 # pad dimension i to amt with invalids
1484 def ipad(t, i, amt):
1485 shape = (None,)*i + (amt,) + (None,)*(t.ndim-i-1)
1486 return Tensor(True, device=t.device).expand(t.shape).pad_to(shape).where(t.pad_to(shape), Invalid) if amt != t.shape[i] else t
1487 # align a dimension, use at to specify the dimension to pad in, defaults to first
1488 def pad_align(t, dim, at=None, force=False):
1489 # align to 64 pixels when height is real, otherwise 64 bytes is sufficient
1490 align = (64 // dtsz) if prod(t.shape[:dim]) == 1 or prod(t.shape) < 16384 * 4 else 256
1491 return ipad(t, at:=at or dim, round_up(t.shape[at] + int(force), align // math.gcd(prod(t.shape[dim:]) // t.shape[at], align)))
1492
1493 # bank conflicts
1494 bank_conflict = cin >= 8 and is_pow2(cin // 4)
1495 if bank_conflict:
1496 x, w = pad_align(x.reshape(bs, iy, ix, groups, cin // 4, 4), 2, at=4, force=True), pad_align(w, 1, at=2, force=True)
1497 else: x, w = pad_align(x, 2), pad_align(w, 1)
1498
1499 # contiguous creates the image, and early realize static weights (TODO: test for the static weight)
1500 if FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float)
1501 else: x, w = x.contiguous(), w.contiguous()
1502
1503 # undo alignment hacks
1504 if bank_conflict: x, w = x[:, :, :, :, :cin // 4, :], w[:, :, :cin // 4, ...]
1505 else: x, w = x[:, :, :ix, :], w[:, :H, ...]

Callers 3

conv2d · Method · 0.95
image_dot · Method · 0.80

Calls 11

resolve_pool_pads · Function · 0.90
round_up · Function · 0.90
reshape · Method · 0.80
pad_to · Method · 0.80
permute · Method · 0.80
_pool · Method · 0.80
sum · Method · 0.80
pad · Method · 0.45
cast · Method · 0.45
contiguous · Method · 0.45
add · Method · 0.45

Tested by 1