| 90 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) |
| 91 | |
def get_codebook_entry(self, indices, shape):
    """Look up the codebook vectors selected by ``indices``.

    Args:
        indices: integer tensor of codebook indices. The previous one-hot
            implementation required a 1-D tensor; that case behaves the
            same here.
        shape: target shape given as (batch, height, width, channel), or
            ``None`` to return the looked-up vectors without reshaping.

    Returns:
        The selected rows of ``self.embedding.weight``. When ``shape`` is
        given, the rows are viewed as ``shape`` and permuted to
        (batch, channel, height, width) in contiguous memory.
    """
    # Index the codebook rows directly rather than building an
    # (N, n_e) one-hot matrix and multiplying it with the weights: the
    # selected rows are the same, without the O(N * n_e) temporary and
    # without forcing a float32 matmul against the codebook dtype.
    z_q = self.embedding.weight[indices]

    if shape is not None:
        z_q = z_q.view(shape)

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

    return z_q
| 108 | |
| 109 | |
| 110 | class VectorQuantizer2(nn.Module): |