Repository: github.com/zju3dv/4K4D

Method replace

easyvolcap/utils/vq_utils.py:267–275
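Signature: replace(self, batch_samples, batch_mask)

Re-initializes the codebook entries selected by batch_mask with L2-normalized vectors sampled from batch_samples, processing one codebook (one slice along dim 0) at a time.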


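```python
import torch
from einops import rearrange

# `l2norm` is a helper defined elsewhere in easyvolcap/utils/vq_utils.py.


def replace(self, batch_samples, batch_mask):
    """Overwrite masked codebook entries with vectors sampled from the batch.

    batch_samples: input vectors, one slice per codebook along dim 0
    batch_mask:    boolean mask, one slice per codebook; True marks a code to replace
    """
    # Normalize samples to unit length before they become codes.
    batch_samples = l2norm(batch_samples)

    for ind, (samples, mask) in enumerate(zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0))):
        if not torch.any(mask):
            continue  # nothing to replace in this codebook

        # Draw one vector per masked code; sample_fn expects a leading batch axis.
        sampled = self.sample_fn(rearrange(samples, '... -> 1 ...'), mask.sum().item())
        # Drop that axis again and write the samples into the masked rows in place.
        self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...')
```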
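To see the per-codebook logic of `replace` in isolation, here is a minimal NumPy sketch that runs without the surrounding class. It is an illustration under stated assumptions, not EasyVolcap's implementation: `embed` is assumed to be a `(num_codebooks, codebook_size, dim)` array, `l2norm` is assumed to normalize along the last axis, and the hypothetical `sample_vectors` stands in for `self.sample_fn` by drawing rows uniformly at random (without replacement when enough rows exist, with replacement otherwise).

```python
import numpy as np


def l2norm(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Assumed behavior: scale each vector (last axis) to unit L2 length.
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def sample_vectors(samples: np.ndarray, num: int, rng: np.random.Generator) -> np.ndarray:
    # Hypothetical stand-in for self.sample_fn: pick `num` rows of `samples`.
    n = samples.shape[0]
    if n >= num:
        idx = rng.permutation(n)[:num]          # without replacement
    else:
        idx = rng.integers(0, n, size=num)      # too few rows: with replacement
    return samples[idx]


def replace_dead_codes(embed, batch_samples, batch_mask, rng):
    """Overwrite masked rows of `embed` in place with normalized input samples.

    embed:         (num_codebooks, codebook_size, dim)
    batch_samples: (num_codebooks, num_samples, dim)
    batch_mask:    (num_codebooks, codebook_size) bool, True = replace this code
    """
    batch_samples = l2norm(batch_samples)
    for ind, (samples, mask) in enumerate(zip(batch_samples, batch_mask)):
        if not mask.any():
            continue  # no codes to replace in this codebook
        # embed[ind] is a view, so the masked assignment writes into embed.
        embed[ind][mask] = sample_vectors(samples, int(mask.sum()), rng)
    return embed
```

With a mask that selects rows 0 and 2 of the first codebook, only those two rows change, and each becomes a unit-length copy of one of that codebook's input samples. Normalizing before sampling means a replacement code already lies on the unit sphere, matching what `replace` does via `l2norm`.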

Callers (1)

expire_codes_ (method) · 0.95
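In vq_utils.py, the caller `expire_codes_(self, batch_samples, verbose)` begins by testing `self.threshold_ema_dead_code == 0`. How it then builds `batch_mask` is not shown on this page, so the NumPy sketch below is an assumption: it supposes a per-code exponential-moving-average usage count (given the hypothetical name `cluster_size`), treats a threshold of 0 as "expiry disabled", and flags codes whose count falls below the threshold.

```python
from typing import Optional

import numpy as np


def dead_code_mask(cluster_size: np.ndarray, threshold_ema_dead_code: float) -> Optional[np.ndarray]:
    """Hypothetical helper: boolean mask of codes to replace, or None.

    cluster_size: (num_codebooks, codebook_size) EMA usage counts (assumed input)
    """
    # Assumption: a threshold of 0 disables dead-code expiry entirely.
    if threshold_ema_dead_code == 0:
        return None
    # Assumption: a code is "dead" when its EMA usage is below the threshold.
    mask = cluster_size < threshold_ema_dead_code
    # Skip the replacement step when no code qualifies.
    return mask if mask.any() else None
```

A mask produced this way has the per-codebook boolean layout that `replace` unbinds along dim 0.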

Calls (1)

l2norm (function) · 0.85

Tested by

no test coverage detected