(self, hidden_states)
| 375 | self.embed = mx.zeros((config.codebook_size, config.codebook_dim)) |
| 376 | |
def quantize(self, hidden_states):
    """Return the index of the nearest codebook entry for each input vector.

    Nearness is squared Euclidean distance, expanded as
    ``||x||^2 - 2 x.e + ||e||^2``. Taking ``argmax`` of its negation
    selects the closest entry.

    Args:
        hidden_states: array whose LAST axis is the feature axis. It is
            assumed to have size ``codebook_dim`` (TODO confirm against
            callers). Any leading axes are kept, so both ``(N, dim)`` and
            batched ``(..., dim)`` inputs work.

    Returns:
        Integer index array with the shape of ``hidden_states`` minus its
        last axis. Each value indexes a row of ``self.embed``, which is
        ``(codebook_size, codebook_dim)``.
    """
    # (codebook_dim, codebook_size): one codebook entry per column.
    codebook = self.embed.T
    # Sum over the same axis the matmul below contracts (the last one).
    # The previous axis=1 only matched this for 2-D input.
    input_norms = hidden_states.square().sum(axis=-1, keepdims=True)
    # Squared norm of each codebook entry, shape (1, codebook_size).
    code_norms = codebook.square().sum(axis=0, keepdims=True)
    dist = -(input_norms - 2 * hidden_states @ codebook + code_norms)
    embed_ind = dist.argmax(axis=-1)
    return embed_ind
| 387 | |
| 388 | def encode(self, hidden_states): |
| 389 | shape = hidden_states.shape |
no outgoing calls
no test coverage detected