(self, input_pos, k_val, v_val)
| 175 | self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) |
| 176 | |
def update(self, input_pos, k_val, v_val):
    """Write new key/value entries into the cache at the given positions.

    Args:
        input_pos: 1-D index tensor of length S; the positions along dim 2
            of the cache buffers that are overwritten.
        k_val: keys to store, laid out as [B, H, S, D].
        v_val: values to store; its dim 2 must also have length S
            (presumably the same [B, H, S, D] layout as ``k_val``).

    Returns:
        The tuple ``(self.k_cache, self.v_cache)``. These are the cache
        buffers themselves (not copies), already updated in place.

    Raises:
        AssertionError: if the length of ``input_pos`` differs from the
            dim-2 length of ``k_val`` or ``v_val``.
    """
    seq_len = input_pos.shape[0]
    # Check both tensors before touching either buffer. Indexed assignment
    # broadcasts its right-hand side, so a v_val with a mismatched (e.g.
    # length-1) sequence dimension would otherwise be written silently.
    # Validating first also keeps a rejected call from half-updating the cache.
    assert seq_len == k_val.shape[2], "k_val dim 2 must match len(input_pos)"
    assert seq_len == v_val.shape[2], "v_val dim 2 must match len(input_pos)"

    k_out = self.k_cache
    v_out = self.v_cache
    # In-place scatter along the sequence dimension (dim 2).
    k_out[:, :, input_pos] = k_val
    v_out[:, :, input_pos] = v_val

    return k_out, v_out
| 186 | |
| 187 | |
| 188 | class Attention(nn.Module): |
no outgoing calls
no test coverage detected