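Args: x1: features of modality 1; x2: features of modality 2; y: indices on the current node; x1_jig: jigsaw features of modality 1; x2_jig: jigsaw features of modality 2; all_x1: features gathered across nodes (otherwise x1 is used); all_x2: features gathered across nodes (otherwise x2 is used); all_y: indices gathered across nodes (otherwise y is used).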
(self, x1, x2, y, x1_jig=None, x2_jig=None,
all_x1=None, all_x2=None, all_y=None)
| 105 | self.memory_2 = F.normalize(self.memory_2) |
| 106 | |
def forward(self, x1, x2, y, x1_jig=None, x2_jig=None,
            all_x1=None, all_x2=None, all_y=None):
    """Score features against sampled memory-bank rows, then update the banks.

    For each sample, ``self.K + 1`` memory indices are drawn from
    ``self.multinomial``. Column 0 is then overwritten with the sample's own
    index ``y``, so the positive sits at position 0 and the remaining ``K``
    columns act as negatives. Features of one modality are scored against
    the memory of the *other* modality (``x1`` vs ``memory_2`` and ``x2`` vs
    ``memory_1``).

    Args:
        x1: feat of modal 1. ``x1.size(0)`` is the batch size, and
            ``x1.size(1)`` must equal the width of the memory rows.
        x2: feat of modal 2.
        y: index on current node (one memory index per sample).
        x1_jig: jigsaw feat of modal 1 (optional).
        x2_jig: jigsaw feat of modal 2 (optional).
        all_x1: gather of feats across nodes; otherwise use x1.
        all_x2: gather of feats across nodes; otherwise use x2.
        all_y: gather of index across nodes; otherwise use y.

    Returns:
        ``(logits1, logits2, labels)``, or
        ``(logits1, logits2, logits1_jig, logits2_jig, labels)`` when both
        jigsaw features are given. ``labels`` is an all-zero ``torch.long``
        tensor of length ``bsz`` (the positive is always column 0), created
        on the same device as ``x1``.

    Note:
        The gathered ``all_*`` tensors are used for the memory update only
        when all three are provided; otherwise the local batch is used.
        ``torch.index_select`` copies the sampled rows, and the update runs
        after the logits are computed, so the logits always see the
        pre-update memory.
    """
    bsz, n_dim = x1.size(0), x1.size(1)
    n_cols = self.K + 1

    # Sample K + 1 memory indices per example, then force column 0 to the
    # example's own index so that column is the positive.
    idx = self.multinomial.draw(bsz * n_cols).view(bsz, -1)
    idx.select(1, 0).copy_(y.detach())
    flat_idx = idx.view(-1)

    w1 = torch.index_select(self.memory_1, 0, flat_idx)
    w1 = w1.view(bsz, n_cols, n_dim)
    w2 = torch.index_select(self.memory_2, 0, flat_idx)
    w2 = w2.view(bsz, n_cols, n_dim)

    # Cross-modal scoring: each modality is compared with the other's memory.
    logits1 = self._compute_logit(x1, w2)
    logits2 = self._compute_logit(x2, w1)
    use_jig = (x1_jig is not None) and (x2_jig is not None)
    if use_jig:
        logits1_jig = self._compute_logit(x1_jig, w2)
        logits2_jig = self._compute_logit(x2_jig, w1)

    # The positive is column 0 for every row. Build the labels on the inputs'
    # device instead of calling .cuda(): .cuda() fails on CPU-only runs and
    # targets the *current* CUDA device, which may differ from x1's device.
    labels = torch.zeros(bsz, dtype=torch.long, device=x1.device)

    # Update memory with the cross-node gather when fully provided,
    # otherwise with the local batch.
    if (all_x1 is not None) and (all_x2 is not None) and (all_y is not None):
        self._update_memory(self.memory_1, all_x1, all_y)
        self._update_memory(self.memory_2, all_x2, all_y)
    else:
        self._update_memory(self.memory_1, x1, y)
        self._update_memory(self.memory_2, x2, y)

    if use_jig:
        return logits1, logits2, logits1_jig, logits2_jig, labels
    return logits1, logits2, labels
nothing calls this directly
no test coverage detected