hub / github.com/HobbitLong/PyContrast / forward

Method forward

pycontrast/memory/mem_bank.py:107–154

Args:
    x1: feat of modal 1
    x2: feat of modal 2
    y: index on current node
    x1_jig: jigsaw feat of modal1
    x2_jig: jigsaw feat of modal2
    all_x1: gather of feats across nodes; otherwise use x1
    all_x2: gather of feats across nodes; otherwise use x2
    all_y: gather of index across nodes; otherwise use y

(self, x1, x2, y, x1_jig=None, x2_jig=None,
                all_x1=None, all_x2=None, all_y=None)

Source from the content-addressed store, hash-verified

        self.memory_2 = F.normalize(self.memory_2)

    def forward(self, x1, x2, y, x1_jig=None, x2_jig=None,
                all_x1=None, all_x2=None, all_y=None):
        """
        Args:
          x1: feat of modal 1
          x2: feat of modal 2
          y: index on current node
          x1_jig: jigsaw feat of modal1
          x2_jig: jigsaw feat of modal2
          all_x1: gather of feats across nodes; otherwise use x1
          all_x2: gather of feats across nodes; otherwise use x2
          all_y: gather of index across nodes; otherwise use y
        """
        bsz = x1.size(0)
        n_dim = x1.size(1)

        # sample negative features
        idx = self.multinomial.draw(bsz * (self.K + 1)).view(bsz, -1)
        idx.select(1, 0).copy_(y.data)

        w1 = torch.index_select(self.memory_1, 0, idx.view(-1))
        w1 = w1.view(bsz, self.K + 1, n_dim)
        w2 = torch.index_select(self.memory_2, 0, idx.view(-1))
        w2 = w2.view(bsz, self.K + 1, n_dim)

        # compute logits
        logits1 = self._compute_logit(x1, w2)
        logits2 = self._compute_logit(x2, w1)
        if (x1_jig is not None) and (x2_jig is not None):
            logits1_jig = self._compute_logit(x1_jig, w2)
            logits2_jig = self._compute_logit(x2_jig, w1)

        # set label
        labels = torch.zeros(bsz, dtype=torch.long).cuda()

        # update memory
        if (all_x1 is not None) and (all_x2 is not None) \
                and (all_y is not None):
            self._update_memory(self.memory_1, all_x1, all_y)
            self._update_memory(self.memory_2, all_x2, all_y)
        else:
            self._update_memory(self.memory_1, x1, y)
            self._update_memory(self.memory_2, x2, y)

        if (x1_jig is not None) and (x2_jig is not None):
            return logits1, logits2, logits1_jig, logits2_jig, labels
        else:
            return logits1, logits2, labels
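
The sampling and logit steps of forward can be illustrated with a small, CPU-only sketch. Column 0 of idx is overwritten with the positive index y, which is why labels is all zeros, and the logits are cross-modal (x1 against rows of memory_2, x2 against rows of memory_1). Everything else here is an assumption made for illustration: the tensor sizes, the temperature T, uniform sampling via torch.randint in place of self.multinomial.draw, and the compute_logit helper, which is a hypothetical stand-in because the body of _compute_logit is not shown on this page.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes and temperature (assumptions, not taken from the source).
n_data, n_dim, K, bsz, T = 50, 8, 4, 3, 0.07

# Stand-ins for the two memory banks: one L2-normalized row per dataset item.
memory_1 = F.normalize(torch.randn(n_data, n_dim))
memory_2 = F.normalize(torch.randn(n_data, n_dim))

# Current-batch features for both modalities and their dataset indices.
x1 = F.normalize(torch.randn(bsz, n_dim))
x2 = F.normalize(torch.randn(bsz, n_dim))
y = torch.tensor([5, 17, 42])

# Draw K + 1 indices per sample (uniform here; the source uses self.multinomial),
# then overwrite column 0 in place with the positive index y.
idx = torch.randint(0, n_data, (bsz * (K + 1),)).view(bsz, -1)
idx.select(1, 0).copy_(y)

# Gather memory rows: shape (bsz, K + 1, n_dim), positive entry first.
w1 = torch.index_select(memory_1, 0, idx.view(-1)).view(bsz, K + 1, n_dim)
w2 = torch.index_select(memory_2, 0, idx.view(-1)).view(bsz, K + 1, n_dim)


def compute_logit(x, w, temperature):
    """Hypothetical stand-in for _compute_logit: temperature-scaled dot products."""
    return torch.bmm(w, x.unsqueeze(2)).squeeze(2) / temperature


# Cross-modal logits, as in the source: x1 vs. memory_2 rows, x2 vs. memory_1 rows.
logits1 = compute_logit(x1, w2, T)
logits2 = compute_logit(x2, w1, T)

# The positive sits at position 0 for every sample, so the target class is 0.
labels = torch.zeros(bsz, dtype=torch.long)
loss = F.cross_entropy(logits1, labels) + F.cross_entropy(logits2, labels)
```

With this layout the returned (logits, labels) pairs can be passed straight to a cross-entropy criterion; summing the two losses is one plausible choice, not something this page confirms.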

Callers

nothing calls this directly

Calls 4

draw · Method · 0.80
cuda · Method · 0.80
_compute_logit · Method · 0.45
_update_memory · Method · 0.45
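
The body of _update_memory is not shown on this page; only its call shape (memory, features, indices) is visible in forward. As a rough sketch of what such a step might look like, the following assumes a momentum-style update followed by re-normalization, which is a common pattern for memory banks. The function name update_memory, the momentum value, and the update rule itself are all assumptions, not the repository's confirmed implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def update_memory(memory, x, y, momentum=0.5):
    """Hypothetical stand-in for _update_memory (assumed momentum update)."""
    with torch.no_grad():
        # Rows currently stored for the indices in this batch.
        old = torch.index_select(memory, 0, y)
        # Blend stored rows with the new features, then re-normalize to unit length.
        new = F.normalize(momentum * old + (1.0 - momentum) * x.detach())
        # Write the blended rows back in place.
        memory.index_copy_(0, y, new)


memory = F.normalize(torch.randn(10, 4))
before = memory.clone()
x = F.normalize(torch.randn(2, 4))
y = torch.tensor([3, 7])
update_memory(memory, x, y)
```

Only the rows named by y change; passing the gathered all_x1, all_x2, and all_y instead of the local batch would, under the same assumption, update the rows for every node's samples at once.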

Tested by

no test coverage detected