MCPcopy
hub / github.com/tinygrad/tinygrad / keccak

Method keccak

tinygrad/tensor.py:1106–1159  ·  view source on GitHub ↗

Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.

```python exec="false" source="above" session="tensor" result="python"
t = Tensor(b"Hello World!").keccak()
print(t.data().hex())
```

(self, cfg:str|tuple[int, int]="sha3_256")

Source from the content-addressed store, hash-verified

1104 # ***** reduce ops *****
1105
1106 def keccak(self, cfg:str|tuple[int, int]="sha3_256"):
1107 """
1108 Calculates a Keccak hash over the last dimension. Uses "sha3_256" by default.
1109
1110 ```python exec="false" source="above" session="tensor" result="python"
1111 t = Tensor(b"Hello World!").keccak()
1112 print(t.data().hex())
1113 ```
1114 """
1115 # `cfg` is either a preset name ("sha3_224", "sha3_256", "shake_128") or an explicit (rate, dsbyte) tuple.
1116 # https://keccak.team/keccak_specs_summary.html
1117 # The result is uint8 with shape (*self.shape[:-1], (200 - rate) // 2).
1118 def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
1119 # TODO: contiguous is here for compile speed
1120 return Tensor.stack(*(Tensor(v, dtype=dtype, device=self.device) for v in l)).contiguous()
1121 rot_offsets = [44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6, 25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2]
1122 rot_offsets_v0, rot_offsets_v1 = ctensor([0] + [1 << v for v in rot_offsets]), ctensor([1] + [1 << (64 - v) for v in rot_offsets])  # per-lane multiplier 2**v and divisor 2**(64-v); the leading 0 and 1 leave lane 0 unchanged
1123
1124 # calculated from π step
1125 reorder_indexes = ctensor([0,6,12,18,24,3,9,10,16,22,1,7,13,19,20,4,5,11,17,23,2,8,14,15,21], dtype=dtypes.int32)
1126 rnd_const_masks = [ctensor([v]).pad((0, 24)) for v in (1, 0x8082, 0x800000000000808a, 0x8000000080008000, 0x808b, 0x80000001, 0x8000000080008081,
1127 0x8000000000008009, 0x8a, 0x88, 0x80008009, 0x8000000a, 0x8000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
1128 0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]  # one mask per round: the constant followed by 24 zeros, so XORing it only touches lane 0
1129
1130 rate, dsbyte = {"sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31)}[cfg] if isinstance(cfg, str) else cfg  # rate = bytes consumed per block; dsbyte = byte XORed in right after the message
1131 data = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1])  # NOTE(review): the reshape uses self.shape[-1] from before the uint8 bitcast - presumably assumes a 1-byte input dtype; confirm
1132 data_pad = rate - data.shape[-1] % rate  # number of padding bytes, always in 1..rate
1133 # pad batches then pad blocks
1134 data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad_to(None, None, 200)
1135
1136 # create pad mask
1137 lbe = (data.shape[1] - 1) * 200 + rate - data_pad  # offset of the first padding byte within a flattened row of blocks * 200 bytes
1138 if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)]  # (run length, byte value) pairs; a single pad byte carries both dsbyte and 0x80
1139 else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)]  # dsbyte first, 0x80 on the last byte of the rate portion, zeros elsewhere
1140 pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0)  # zero-length runs are skipped
1141 # XOR the pad bytes into the data, then view every 200-byte block as 25 uint64 lanes
1142 data = (data.flatten(1) ^ pad_mask).reshape(*data.shape[:2], 200).bitcast(dtypes.uint64)
1143
1144 state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64, buffer=False)
1145 for k in range(int(data.shape[1])):  # XOR each block into the state, then run the 24 rounds below
1146 state = state ^ data[:, k]
1147 for i in range(24): # f1600
1148 # θ step
1149 p = state.reshape(bs, 5, 5).transpose(2, 1)
1150 t1 = (p[:,:,0] ^ p[:,:,1] ^ p[:,:,2] ^ p[:,:,3] ^ p[:,:,4]).roll(-1, 1) # xor reduce
1151 state = state ^ (t1.roll(2, 1).bitwise_xor((t1 << 1) ^ (t1 >> 63)).unsqueeze(2).expand(bs, 5, 5).transpose(2, 1).flatten(1))
1152 # ρ and π steps
1153 state = state[:, reorder_indexes]
1154 state = (state * rot_offsets_v0).bitwise_or(state // rot_offsets_v1).reshape(bs, 5, 5)  # acts as a left-rotate by v bits, assuming uint64 multiplication wraps modulo 2**64
1155 # χ and ι step
1156 state = state.bitwise_xor(~state.roll(shifts=-1, dims=2) & state.roll(shifts=-2, dims=2))
1157 state = state.flatten(1) ^ rnd_const_masks[i]
1158 # NOTE: there was a kernelize here to prevent internal stack from growing proportional to data size, do we need something else?
1159 return state.bitcast(dtypes.uint8)[:,:(obytes:=(200 - rate) // 2)].reshape(*self.shape[:-1], obytes)  # keep the first (200 - rate) // 2 bytes of the state and restore the leading dims
1160
1161 def _hash_1mb(self) -> Tensor:
1162 assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing"

Callers 9

_hash_1mbMethod · 0.80
test_shape_keepingMethod · 0.80
_test_presetMethod · 0.80
test_referencedMethod · 0.80
test_longMethod · 0.80
test_variable_bsMethod · 0.80
fMethod · 0.80
check_nist_vectorsMethod · 0.80
hasherFunction · 0.80

Calls 15

bitcastMethod · 0.95
prodFunction · 0.90
TensorClass · 0.85
reshapeMethod · 0.80
pad_toMethod · 0.80
unsqueezeMethod · 0.80
catMethod · 0.80
expandMethod · 0.80
flattenMethod · 0.80
zerosMethod · 0.80
rollMethod · 0.80
bitwise_xorMethod · 0.80

Tested by 7

test_shape_keepingMethod · 0.64
_test_presetMethod · 0.64
test_referencedMethod · 0.64
test_longMethod · 0.64
test_variable_bsMethod · 0.64
fMethod · 0.64
check_nist_vectorsMethod · 0.64