| 1446 | return cx.image_conv2d(cw, groups=groups, dtype=dtype).reshape(out_shape_t).transpose(self.ndim-1, self.ndim-2) |
| 1447 | |
| 1448 | def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: |
| 1449 | dtsz = 2 if FLOAT16 else 4 |
| 1450 | |
| 1451 | (bs,_,_,_), (cout,cin,H,W) = self.shape, weight.shape |
| 1452 | assert isinstance(cin, int) and isinstance(cout, int) |
| 1453 | x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W) |
| 1454 | |
| 1455 | padding_neg, padding_pos = [min(0, p) for p in resolve_pool_pads(padding, 2)], [max(0, p) for p in resolve_pool_pads(padding, 2)] |
| 1456 | x = x.pad(padding_neg) |
| 1457 | iy, ix = x.shape[2:] |
| 1458 | |
| 1459 | # hack for non multiples of 4 on cin |
| 1460 | if cin % 4 != 0 and not (cin == 1 and groups%4 == 0): |
| 1461 | new_cin = round_up(cin, 4) |
| 1462 | w = w.pad_to(None, None, new_cin, None, None) |
| 1463 | x = x.reshape(bs, groups, cin, iy, ix) |
| 1464 | x = x.pad_to(None, None, new_cin, None, None).reshape(bs, groups*new_cin, iy, ix) |
| 1465 | cin = new_cin |
| 1466 | |
| 1467 | # hack for non multiples of 4 on rcout |
| 1468 | added_output_channels = 0 |
| 1469 | if rcout % 4 != 0 and not (rcout == 1 and groups%4 == 0): |
| 1470 | added_output_channels = 4 - (rcout % 4) |
| 1471 | rcout += added_output_channels |
| 1472 | cout = groups * rcout |
| 1473 | w = w.pad_to(None, rcout, None, None, None) |
| 1474 | |
| 1475 | # packed (note: flipping bs and iy would make the auto-padding work) |
| 1476 | x = x.permute(0,2,3,1) |
| 1477 | cin_last = iy == 1 and ix == 1 |
| 1478 | if cin == 1: w = w.reshape(cout//4,4,H,W).permute(0,2,3,1) |
| 1479 | elif cin_last: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,1,3) |
| 1480 | else: w = w.reshape(cout//4,4,cin//4,4,H,W).permute(0,4,2,5,3,1) |
| 1481 | |
| 1482 | def is_pow2(v): return v > 0 and v & (v - 1) == 0 |
| 1483 | # pad dimension i of t up to size amt: an all-True mask shaped like t is padded the same way and selects the padded t where it is set, Invalid elsewhere (assumes pad_to fills new slots with a falsy value - TODO confirm); t is returned unchanged when t.shape[i] already equals amt |
| 1484 | def ipad(t, i, amt): |
| 1485 | shape = (None,)*i + (amt,) + (None,)*(t.ndim-i-1) |
| 1486 | return Tensor(True, device=t.device).expand(t.shape).pad_to(shape).where(t.pad_to(shape), Invalid) if amt != t.shape[i] else t |
| 1487 | # align a dimension via ipad: at names the dimension that receives the padding and falls back to dim when it is None (or 0); force=True adds 1 to that dimension's size before rounding up, presumably so padding is always applied |
| 1488 | def pad_align(t, dim, at=None, force=False): |
| 1489 | # alignment in elements: 64 bytes worth (64 // dtsz) when the product of the dims before dim is 1 or t has fewer than 16384 * 4 elements, otherwise 256 (original note: 64 pixels when height is real); the rounding multiple below is align divided by its gcd with the element count of the dims from dim onward excluding at |
| 1490 | align = (64 // dtsz) if prod(t.shape[:dim]) == 1 or prod(t.shape) < 16384 * 4 else 256 |
| 1491 | return ipad(t, at:=at or dim, round_up(t.shape[at] + int(force), align // math.gcd(prod(t.shape[dim:]) // t.shape[at], align))) |
| 1492 | |
| 1493 | # bank conflicts |
| 1494 | bank_conflict = cin >= 8 and is_pow2(cin // 4) |
| 1495 | if bank_conflict: |
| 1496 | x, w = pad_align(x.reshape(bs, iy, ix, groups, cin // 4, 4), 2, at=4, force=True), pad_align(w, 1, at=2, force=True) |
| 1497 | else: x, w = pad_align(x, 2), pad_align(w, 1) |
| 1498 | |
| 1499 | # contiguous creates the image, and early realize static weights (TODO: test for the static weight) |
| 1500 | if FLOAT16: x, w = x.cast(dtypes.half).contiguous().cast(dtypes.float), w.cast(dtypes.half).contiguous().cast(dtypes.float) |
| 1501 | else: x, w = x.contiguous(), w.contiguous() |
| 1502 | |
| 1503 | # undo alignment hacks |
| 1504 | if bank_conflict: x, w = x[:, :, :, :, :cin // 4, :], w[:, :, :cin // 4, ...] |
| 1505 | else: x, w = x[:, :, :ix, :], w[:, :H, ...] |