(self, hidden_states)
| 386 | return embed_ind |
| 387 | |
def encode(self, hidden_states):
    """Map each feature vector in ``hidden_states`` to an index via ``self.quantize``.

    All leading dimensions are flattened, so ``self.quantize`` receives a 2-D
    input of shape ``(num_vectors, hidden_states.shape[-1])``. The indices it
    returns are then reshaped back to ``hidden_states.shape[:-1]``.

    Args:
        hidden_states: Tensor whose last dimension is the feature dimension.

    Returns:
        Index tensor with shape ``hidden_states.shape[:-1]``. This is a 0-d
        tensor when ``hidden_states`` is 1-D.
    """
    shape = hidden_states.shape
    flat = hidden_states.reshape((-1, shape[-1]))
    # NOTE(review): assumes self.quantize returns one index per row of `flat`
    # (i.e. its numel equals flat.shape[0]) — confirm against quantize.
    embed_ind = self.quantize(flat)
    # Pass the leading shape as one tuple instead of unpacking it.
    # `reshape(*())` raises for 1-D input, while `reshape(())` yields a 0-d tensor.
    return embed_ind.reshape(shape[:-1])
| 394 | |
def decode(self, embed_ind):
    """Return the entries of ``self.embed`` selected by ``embed_ind``.

    Args:
        embed_ind: Index, or index tensor, used to subscript ``self.embed``.
            Presumably these are codebook indices such as those produced by
            ``encode``; verify against callers.

    Returns:
        The result of ``self.embed[embed_ind]``.
    """
    table = self.embed
    selected = table[embed_ind]
    return selected