(self, z)
| 137 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) |
| 138 | |
def forward(self, z):
    """Snap ``z`` to its nearest codebook entries (straight-through).

    Every row of the preprocessed input is replaced by the vector of
    ``self.embedding`` that minimises the squared Euclidean distance.

    Args:
        z: 3-D tensor whose shape is unpacked as ``(N, width, T)``.
            Presumably ``width == self.e_dim`` and ``self.preprocess``
            moves that axis last -- TODO confirm against ``preprocess``,
            which is not visible here.

    Returns:
        A tuple ``(z_q, loss, perplexity)``:

        * ``z_q`` -- quantized values viewed as ``(N, T, -1)`` and then
          permuted to ``(N, -1, T)``.  Its value is the selected codebook
          vectors, but its gradient is routed to ``z``.
        * ``loss`` -- ``mean((z_q - sg(z))**2)`` plus ``self.beta`` times
          ``mean((sg(z_q) - z)**2)``, where ``sg`` is ``detach``.
        * ``perplexity`` -- ``exp`` of the entropy of the average one-hot
          code assignment over all rows.

    Raises:
        ValueError: if ``z`` is not 3-D (tuple unpacking of ``z.shape``).
        AssertionError: if the preprocessed last dimension differs from
            ``self.e_dim``.
    """
    # Only the first and last extents are needed later; the unpacking
    # itself also rejects inputs that are not exactly 3-D.
    batch, _, steps = z.shape

    z = self.preprocess(z)
    assert z.shape[-1] == self.e_dim
    flat = z.contiguous().view(-1, self.e_dim)
    codebook = self.embedding.weight

    # Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b, giving
    # one row per input vector and one column per codebook entry.
    flat_sq = torch.sum(flat ** 2, dim=1, keepdim=True)
    code_sq = torch.sum(codebook ** 2, dim=1)
    cross = torch.matmul(flat, codebook.t())
    distances = flat_sq + code_sq - 2 * cross

    # Index of the closest codebook entry for every row.
    nearest = torch.argmin(distances, dim=1)
    quantized = self.embedding(nearest).view(z.shape)

    # First term updates the codebook only (input detached); the second
    # updates the input only (codebook detached) and is scaled by beta.
    codebook_term = torch.mean((quantized - z.detach()) ** 2)
    commitment_term = torch.mean((quantized.detach() - z) ** 2)
    loss = codebook_term + self.beta * commitment_term

    # Straight-through: forward value is the codebook vector, backward
    # gradient passes unchanged to z.
    quantized = z + (quantized - z).detach()
    z_q = quantized.view(batch, steps, -1).permute(0, 2, 1).contiguous()  # (N, DIM, T)

    # Codebook-usage statistic: exp(entropy) of the mean assignment.
    usage = F.one_hot(nearest, self.n_e).type(z.dtype)
    mean_usage = torch.mean(usage, dim=0)
    entropy = -torch.sum(mean_usage * torch.log(mean_usage + 1e-10))
    perplexity = torch.exp(entropy)
    return z_q, loss, perplexity
| 166 | |
| 167 | def quantize(self, z): |
| 168 |
nothing calls this directly
no test coverage detected