Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.

```python exec="false" source="above" session="tensor" result="python"
t = Tensor(b"Hello World!").keccak()
print(t.data().hex())
```
(self, cfg:str|tuple[int, int]="sha3_256")
| 1104 | # ***** reduce ops ***** |
| 1105 | |
def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
  """
  Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.

  `cfg` is either a preset name ("sha3_224", "sha3_256", "sha3_384", "sha3_512", "shake_128", "shake_256")
  or a raw `(rate, dsbyte)` tuple: `rate` is the sponge rate in bytes and `dsbyte` is the domain-separation
  byte placed right after the message. The raw bytes of the last dimension are hashed (non-uint8 inputs are
  bitcast to uint8 first), and the result has `(200 - rate) // 2` bytes taken from a single squeeze.
  Unknown preset names raise `KeyError`; a rate whose digest would not fit in one squeeze raises `ValueError`.

  ```python exec="false" source="above" session="tensor" result="python"
  t = Tensor(b"Hello World!").keccak()
  print(t.data().hex())
  ```
  """

  # https://keccak.team/keccak_specs_summary.html

  def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
    # TODO: contiguous is here for compile speed
    return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()

  # ρ rotation amounts for lanes 1..24, indexed by lane position after the π reorder (lane 0 is never rotated).
  # rotl(x, r) is built as x * 2**r | x // 2**(64 - r), assuming uint64 multiplication wraps mod 2**64;
  # lane 0 gets multiplier 0 and divisor 1, which leaves it unchanged.
  rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
  rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets])

  # calculated from π step
  reorder_indexes = ctensor([0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21], dtype=dtypes.int32)
  # ι round constants, each placed in lane 0 of an otherwise zero 25-lane vector
  rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081,
    0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
    0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]

  # (rate in bytes, domain-separation byte) for the FIPS 202 instances
  presets = {"sha3_224": (144, 6), "sha3_256": (136, 6), "sha3_384": (104, 6), "sha3_512": (72, 6),
             "shake_128": (168, 31), "shake_256": (136, 31)}
  rate, dsbyte = presets[cfg] if isinstance(cfg, str) else cfg
  obytes = (200 - rate) // 2
  # the digest is read from a single squeeze, so it must fit inside the rate part of the 200-byte state
  if not 0 < rate <= 200 or obytes > rate: raise ValueError(f"unsupported keccak rate {rate}, expected 67..200 bytes")

  # bitcast may change the last dimension (e.g. uint32 -> 4x as many uint8), so flatten using the byte view's own shape
  data = self.bitcast(dtypes.uint8)
  data = data.reshape(prod(data.shape[:-1]), data.shape[-1])
  # pad10*1 always appends at least one byte, so data_pad is in [1, rate]
  data_pad = rate - data.shape[-1] % rate
  # pad batches then pad blocks
  data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad_to(None, None, 200)

  # create pad mask: dsbyte goes into the first padding byte, 0x80 into the last rate byte of the final block
  lbe = (data.shape[1] - 1) * 200 + rate - data_pad
  if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)]
  else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)]
  pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0)

  data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)

  state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False)
  for k in range(int(data.shape[1])):
    # absorb: bytes past the rate in each block are zero, so XORing all 25 lanes only changes the rate part
    state = state ^ data[:, k]
    for i in range(24): # f1600
      # θ step
      p = state.reshape(bs, 5, 5).transpose(2, 1)
      t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
      state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
      # ρ and π steps
      state = state[:, reorder_indexes]
      state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)
      # χ and ι step
      state = state.bitwise_xor(~state.roll(shifts=-1, dims=2) & state.roll(shifts=-2, dims=2))
      state = state.flatten(1) ^ rnd_const_masks[i]
  # NOTE: there was a kernelize here to prevent internal stack from growing proportional to data size, do we need something else?
  return state.bitcast(dtypes.uint8)[:, :obytes].reshape(*self.shape[:-1], obytes)
| 1160 | |
| 1161 | def _hash_1mb(self) -> Tensor: |
| 1162 | assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing" |