Args: k: key feature; queue: memory buffer
(self, k, queue)
| 15 | self.index = (self.index + bsz) % self.K |
| 16 | |
| 17 | def _update_memory(self, k, queue): |
| 18 | """ |
| 19 | Args: |
| 20 | k: key feature |
| 21 | queue: memory buffer |
| 22 | """ |
| 23 | with torch.no_grad(): |
| 24 | num_neg = k.shape[0] |
| 25 | out_ids = torch.arange(num_neg).cuda() |
| 26 | out_ids = torch.fmod(out_ids + self.index, self.K).long() |
| 27 | queue.index_copy_(0, out_ids, k) |
| 28 | |
| 29 | def _compute_logit(self, q, k, queue): |
| 30 | """ |