| 275 | self.embed.data[ind][mask] = rearrange(sampled, '1 ... -> ...') |
| 276 | |
def expire_codes_(self, batch_samples, verbose):
    """Replace codes whose cluster size has fallen below the dead-code threshold.

    Does nothing when ``self.threshold_ema_dead_code`` is 0 or when no entry
    of ``self.cluster_size`` is below that threshold. Otherwise the samples
    are flattened and handed to ``self.replace`` together with the boolean
    mask of expired codes.

    Args:
        batch_samples: tensor matching the pattern ``'h ... d'``; every
            dimension between the first and the last is merged into one.
            (presumably ``h`` is the codebook/head axis and ``d`` the
            feature axis -- confirm against the caller)
        verbose: when truthy, print how many codes are being expired.

    Returns:
        None. Any update happens through ``self.replace``.
    """
    dead_threshold = self.threshold_ema_dead_code
    if dead_threshold == 0:
        # A zero threshold switches code expiry off.
        return

    # Boolean mask with the same shape as self.cluster_size.
    expired_codes = self.cluster_size < dead_threshold

    if torch.any(expired_codes):
        if verbose:
            print(f'expire code count: {expired_codes.sum()}')

        # Merge all middle dimensions: 'h ... d' -> 'h (...) d'.
        flat_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
        self.replace(flat_samples, batch_mask=expired_codes)
| 289 | |
| 290 | @autocast(enabled=False) |
| 291 | def forward(self, x, weight=None, verbose=False): |