MCPcopy
hub / github.com/THUDM/CogDL / forward

Method forward

cogdl/wrappers/tools/memory_moco.py:24–61  ·  view source on GitHub ↗
(self, q, k)

Source from the content-addressed store, hash-verified

22 print("using queue shape: ({},{})".format(self.queueSize, inputSize))
23
def forward(self, q, k):
    """Score queries against their positive keys and the key queue, then enqueue the keys.

    Args:
        q: query embeddings. Must be 2-D ``(batch, dim)``, because it is
            multiplied with ``torch.mm`` against ``self.memory`` below.
        k: key embeddings with the same ``batch`` as ``q``. They are detached,
            so no gradient flows into them.

    Returns:
        A 2-D tensor ``(batch, 1 + num_queue_rows)``.
        - Column 0 is the positive logit ``q_i . k_i``.
        - The remaining columns are ``q_i`` dotted with every row of
          ``self.memory``.
        With ``self.use_softmax`` the values are the logits divided by
        ``self.T``. Otherwise they are ``exp(logit / T) / Z``.

    Side effects:
        - If ``self.use_softmax`` is false and ``self.params[0] < 0``, Z is
          estimated from this batch, stored in ``self.params[0]``, and printed.
        - The keys are written into ``self.memory`` starting at row
          ``self.index`` (wrapping modulo ``self.queueSize``), and
          ``self.index`` advances by ``batch`` modulo ``self.queueSize``.
    """
    batch_size = q.shape[0]
    k = k.detach()

    Z = self.params[0].item()

    # Positive logit: row-wise dot product q_i . k_i, giving (batch, 1).
    # reshape (rather than view) accepts inputs of any memory layout.
    l_pos = torch.bmm(q.reshape(batch_size, 1, -1), k.reshape(batch_size, -1, 1))
    l_pos = l_pos.reshape(batch_size, 1)

    # Negative logits against the queue, giving (batch, num_queue_rows).
    # The clone is required: self.memory is overwritten in place below, but
    # autograd keeps this tensor to backpropagate l_neg into q.
    queue = self.memory.clone().detach()
    l_neg = torch.mm(queue, q.transpose(1, 0)).transpose(0, 1)

    out = torch.cat((l_pos, l_neg), dim=1)

    # out is always 2-D here. Do not squeeze(): that would drop the batch
    # dimension when batch_size == 1.
    if self.use_softmax:
        out = torch.div(out, self.T).contiguous()
    else:
        out = torch.exp(torch.div(out, self.T))
        if Z < 0:
            # One-time estimate of the normalization constant. no_grad keeps
            # this batch's autograd graph out of the stored buffer value.
            with torch.no_grad():
                self.params[0] = out.mean() * self.outputSize
            Z = self.params[0].item()
            print("normalization constant Z is set to {:.1f}".format(Z))
        out = torch.div(out, Z).contiguous()

    # Update memory as a circular buffer: write at self.index, wrap modulo queueSize.
    with torch.no_grad():
        queue_size = self.queueSize
        # With more keys than slots, later keys in this batch would overwrite
        # earlier ones. Write only the survivors, so index_copy_ never receives
        # duplicate indices (its write order for duplicates is undefined).
        n_write = min(batch_size, queue_size)
        start = batch_size - n_write
        slots = (torch.arange(start, batch_size, device=self.memory.device) + self.index) % queue_size
        self.memory.index_copy_(0, slots, k[start:])
        self.index = (self.index + batch_size) % queue_size

    return out
62
63
64class NCESoftmaxLoss(nn.Module):

Callers

nothing calls this directly

Calls 2

contiguous (method) · 0.80
clone (method) · 0.45

Tested by

no test coverage detected