(self, indices)
| 177 | return min_encoding_indices |
| 178 | |
def dequantize(self, indices):
    """Map integer code indices back to their embedding vectors.

    Every element of ``indices`` is looked up in ``self.embedding``, and the
    result is arranged so that its shape is ``indices.shape + (self.e_dim,)``.

    Args:
        indices: Integer tensor of code indices, of any shape. It does not
            need to be contiguous (e.g. it may be a transposed or sliced view).

    Returns:
        Tensor of shape ``indices.shape + (self.e_dim,)`` holding the looked-up
        vectors.
    """
    # reshape (not view) so non-contiguous index tensors are accepted;
    # it still returns a view without copying when the layout allows.
    flat_indices = indices.reshape(-1)
    z_q = self.embedding(flat_indices)
    # NOTE(review): assumes self.embedding produces vectors of size self.e_dim
    # (e.g. an nn.Embedding with embedding_dim == e_dim) — confirm in __init__.
    # The lookup result is freshly allocated and contiguous, so no extra
    # .contiguous() call is needed after reshaping.
    return z_q.reshape(indices.shape + (self.e_dim,))
| 185 | |
| 186 | def preprocess(self, x): |
| 187 | # NCT -> NTC -> [NT, C] |