(self, z)
| 165 | return z_q, loss, perplexity |
| 166 | |
def quantize(self, z):
    """Return the index of the nearest codebook entry for each vector in ``z``.

    "Nearest" is measured by squared Euclidean distance against the rows of
    ``self.embedding.weight``.

    Args:
        z: Tensor whose last dimension equals ``self.e_dim``. Any number of
            leading dimensions is accepted; a 2-D ``(B, e_dim)`` input
            behaves exactly as before.

    Returns:
        Integer tensor of shape ``z.shape[:-1]`` holding, for every input
        vector, the row index of the closest codebook entry.

    Raises:
        AssertionError: if the last dimension of ``z`` is not ``self.e_dim``.
    """
    assert z.shape[-1] == self.e_dim, (
        f"expected last dim {self.e_dim}, got {z.shape[-1]}"
    )

    codebook = self.embedding.weight  # V x D (one row per code)
    # Collapse all leading dims so the distance math always runs on N x D,
    # instead of assuming the feature axis is dim 1.
    flat_z = z.reshape(-1, self.e_dim)

    # N x V squared distances via ||z||^2 + ||e||^2 - 2 * z.e
    d = torch.sum(flat_z ** 2, dim=1, keepdim=True) + \
        torch.sum(codebook ** 2, dim=1) - 2 * \
        torch.matmul(flat_z, codebook.t())

    # N nearest-code indices, restored to the caller's leading shape
    min_encoding_indices = torch.argmin(d, dim=1)
    return min_encoding_indices.reshape(z.shape[:-1])
| 178 | |
| 179 | def dequantize(self, indices): |
| 180 |