| 265 | self.initted.data.copy_(torch.Tensor([True])) |
| 266 | |
def replace(self, batch_samples, batch_mask):
    """Overwrite the masked entries of ``self.embed`` with freshly sampled vectors.

    ``batch_samples`` is first passed through ``l2norm``. It and ``batch_mask``
    are then split along dim 0 and walked in lockstep; the position in that
    walk selects which slice of ``self.embed.data`` is written.

    Args:
        batch_samples: Candidate vectors, split along dim 0.
            NOTE(review): presumably one group of samples per codebook --
            confirm against the caller.
        batch_mask: Mask split along dim 0 and used to index
            ``self.embed.data[i]``. NOTE(review): assumed to be a boolean
            tensor marking the entries to replace -- confirm.

    Returns:
        None. ``self.embed.data`` is modified in place.
    """
    normalized = l2norm(batch_samples)
    sample_groups = normalized.unbind(dim=0)
    mask_groups = batch_mask.unbind(dim=0)

    for group_idx, (group_samples, group_mask) in enumerate(zip(sample_groups, mask_groups)):
        # Only groups with at least one masked entry need new vectors.
        if torch.any(group_mask):
            # ``sample_fn`` is handed the samples with a leading singleton axis.
            expanded = rearrange(group_samples, '... -> 1 ...')
            num_to_replace = group_mask.sum().item()
            drawn = self.sample_fn(expanded, num_to_replace)

            # Drop the singleton axis again before the masked in-place write.
            replacement = rearrange(drawn, '1 ... -> ...')
            self.embed.data[group_idx][group_mask] = replacement
| 276 | |
| 277 | def expire_codes_(self, batch_samples, verbose): |
| 278 | if self.threshold_ema_dead_code == 0: |