(self, x, weight=None, verbose=False)
| 289 | |
@autocast(enabled=False)
def forward(self, x, weight=None, verbose=False):
    """Quantize ``x`` against the codebook; in training mode also update it by EMA.

    Args:
        x: input tensor whose last dim is the feature dim ``d``. Inputs with
            fewer than 4 dims get a leading codebook dim of size 1 added here,
            which is removed again before returning.
        weight: optional per-position weights applied to the EMA statistics
            (only used in training mode). They are rescaled to average 1.
            They must broadcast against both the one-hot assignments
            ``(h, n, codebook_size)`` and the flattened input ``(h, n, d)``,
            presumably shaped ``(..., n, 1)``. TODO confirm against callers.
        verbose: forwarded to ``self.expire_codes_``.

    Returns:
        ``(quantize, embed_ind)``: the codebook vectors selected for each
        position, and their indices (shape of ``x`` without its last dim).
    """
    needs_codebook_dim = x.ndim < 4

    # Autocast is disabled for this method, so run distance / EMA math in float32.
    x = x.float()

    if needs_codebook_dim:
        x = rearrange(x, '... -> 1 ...')

    shape, dtype = x.shape, x.dtype
    # (h, ..., d) -> (h, n, d): one row per position, per codebook.
    flatten = rearrange(x, 'h ... d -> h (...) d')
    self.init_embed_(flatten)

    # NOTE(review): this detaches the codebook when it IS learnable, which looks
    # inverted. Upstream vector-quantize-pytorch uses
    # `self.embed if self.learnable_codebook else self.embed.detach()`.
    # `dist` only feeds the index sampling below, so the effect may be nil;
    # confirm intent before changing.
    embed = self.embed if not self.learnable_codebook else self.embed.detach()
    dist = -torch.cdist(flatten, embed, p=2)
    embed_ind = gumbel_sample(dist, dim=-1, temperature=self.sample_codebook_temp)
    embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
    embed_ind = embed_ind.view(*shape[:-1])
    quantize = batched_embedding(embed_ind, self.embed)

    if self.training:
        if weight is None:
            weighted_onehot, weighted_flatten = embed_onehot, flatten
        else:
            # Cast like `x` above: with a half-precision weight,
            # `weight * weight.numel()` can overflow to inf for large inputs.
            weight = weight.float()
            # Rescale so the weights average to 1. This keeps the weighted
            # statistics on the same scale as the unweighted path, where each
            # position contributes exactly 1.
            # NOTE(review): an all-zero weight makes this 0/0, and the NaNs
            # would then flow into the EMA updates below. Callers presumably
            # never pass one; confirm.
            weight = weight * weight.numel() / weight.sum()
            weighted_onehot = embed_onehot * weight
            weighted_flatten = flatten * weight

        # Per-code (weighted) assignment counts for this batch.
        cluster_size = weighted_onehot.sum(dim=1)
        self.all_reduce_fn(cluster_size)
        ema_inplace(self.cluster_size, cluster_size, self.decay)

        # Per-code (weighted) sum of the assigned inputs. The weight is applied
        # once, via `weighted_flatten`.
        embed_sum = einsum('h n d, h n c -> h c d', weighted_flatten, embed_onehot)
        self.all_reduce_fn(embed_sum)
        # Smoothed cluster sizes scaled back to the total count, used as the
        # divisor for the per-code mean.
        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()

        ema_inplace(self.embed, embed_sum / rearrange(cluster_size, '... -> ... 1'), self.decay)
        self.expire_codes_(x, verbose)

    if needs_codebook_dim:
        # Drop the size-1 codebook dim added above.
        quantize = rearrange(quantize, '1 ... -> ...')
        embed_ind = rearrange(embed_ind, '1 ... -> ...')

    return quantize, embed_ind
| 334 | |
| 335 | # main class |
| 336 |
nothing calls this directly
no test coverage detected