MCPcopy
hub / github.com/zju3dv/4K4D / forward

Method forward

easyvolcap/utils/vq_utils.py:291–333  ·  view source on GitHub ↗
(self, x, weight=None, verbose=False)

Source from the content-addressed store, hash-verified

289
@autocast(enabled=False)
def forward(self, x, weight=None, verbose=False):
    """Quantize ``x`` against the Euclidean codebook; in training, EMA-update it.

    Autocast is disabled and inputs are promoted to float32 so that distances
    and the codebook statistics are computed in full precision.

    Args:
        x: Input vectors whose last dim is the feature dim ``d``. Inputs with
            fewer than 4 dims get a leading singleton codebook dim ``h`` added
            here, and it is removed again from the outputs.
        weight: Optional per-sample weights, used only for the training-time
            codebook statistics. They are rescaled to sum to ``weight.numel()``
            (mean 1). They must broadcast against the flattened ``(h, n, c)``
            one-hot and ``(h, n, d)`` samples; presumably shape ``(n, 1)``.
            TODO confirm against callers.
        verbose: Forwarded to ``self.expire_codes_``.

    Returns:
        ``(quantize, embed_ind)``: ``quantize`` is looked up from ``self.embed``
        via ``batched_embedding`` using the chosen code indices ``embed_ind``
        (shape ``x.shape[:-1]``, after the optional leading dim is squeezed out).
    """
    needs_codebook_dim = x.ndim < 4

    # Autocast is off, so promote explicitly to float32.
    x = x.float()
    if weight is not None:
        # Promote weight as well. In float16, `weight * weight.numel()` and
        # `weight.sum()` overflow to inf once values exceed 65504, and that
        # would poison the in-place EMA buffers updated below.
        # NOTE(review): an all-zero `weight` still divides by zero here.
        weight = weight.float()
        weight = weight * weight.numel() / weight.sum()

    if needs_codebook_dim:
        x = rearrange(x, '... -> 1 ...')

    shape, dtype = x.shape, x.dtype
    flatten = rearrange(x, 'h ... d -> h (...) d')  # (h, n, d)
    self.init_embed_(flatten)

    # With a learnable codebook, the assignment distances use a detached copy.
    embed = self.embed if not self.learnable_codebook else self.embed.detach()

    # Smaller Euclidean distance -> larger score for gumbel_sample.
    dist = -torch.cdist(flatten, embed, p=2)
    embed_ind = gumbel_sample(dist, dim=-1, temperature=self.sample_codebook_temp)
    embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)  # (h, n, c)
    embed_ind = embed_ind.view(*shape[:-1])
    quantize = batched_embedding(embed_ind, self.embed)

    if self.training:
        # The weight is applied exactly once to each statistic: to the one-hot
        # for the per-code counts, and to the samples for the per-code sums.
        if weight is None:
            cluster_size = embed_onehot.sum(dim=1)
            weighted_flatten = flatten
        else:
            cluster_size = (embed_onehot * weight).sum(dim=1)
            weighted_flatten = flatten * weight

        self.all_reduce_fn(cluster_size)
        ema_inplace(self.cluster_size, cluster_size, self.decay)

        embed_sum = einsum('h n d, h n c -> h c d', weighted_flatten, embed_onehot)
        self.all_reduce_fn(embed_sum)

        # Smoothed counts, rescaled by the total count. Presumably the smoothing
        # keeps empty codes from dividing by zero (see laplace_smoothing).
        cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum()

        # EMA the per-code means into the codebook, then let expire_codes_ act.
        ema_inplace(self.embed, embed_sum / rearrange(cluster_size, '... -> ... 1'), self.decay)
        self.expire_codes_(x, verbose)

    if needs_codebook_dim:
        quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))

    return quantize, embed_ind
334
335# main class
336

Callers

nothing calls this directly

Calls 7

init_embed_ — Method · 0.95
expire_codes_ — Method · 0.95
gumbel_sample — Function · 0.85
batched_embedding — Function · 0.85
ema_inplace — Function · 0.85
laplace_smoothing — Function · 0.85
type — Method · 0.45

Tested by

no test coverage detected