Args: q: query on current node; k: key on current node; q_jig: optional jigsaw query; all_k: optional keys (features) gathered across nodes — if omitted, the local k is used
(self, q, k, q_jig=None, all_k=None)
| 58 | self.memory = F.normalize(self.memory) |
| 59 | |
| 60 | def forward(self, q, k, q_jig=None, all_k=None): |
| 61 | """ |
| 62 | Args: |
| 63 | q: query on current node |
| 64 | k: key on current node |
| 65 | q_jig: jigsaw query |
| 66 | all_k: gather of feats across nodes; otherwise use k |
| 67 | """ |
| 68 | bsz = q.size(0) |
| 69 | k = k.detach() |
| 70 | |
| 71 | # compute logit |
| 72 | queue = self.memory.clone().detach() |
| 73 | logits = self._compute_logit(q, k, queue) |
| 74 | if q_jig is not None: |
| 75 | logits_jig = self._compute_logit(q_jig, k, queue) |
| 76 | |
| 77 | # set label |
| 78 | labels = torch.zeros(bsz, dtype=torch.long).cuda() |
| 79 | |
| 80 | # update memory |
| 81 | all_k = all_k if all_k is not None else k |
| 82 | self._update_memory(all_k, self.memory) |
| 83 | self._update_pointer(all_k.size(0)) |
| 84 | |
| 85 | if q_jig is not None: |
| 86 | return logits, logits_jig, labels |
| 87 | else: |
| 88 | return logits, labels |
| 89 | |
| 90 | |
| 91 | class CMCMoCo(BaseMoCo): |
nothing calls this directly
no test coverage detected