(self, q, k)
| 22 | print("using queue shape: ({},{})".format(self.queueSize, inputSize)) |
| 23 | |
def forward(self, q, k):
    """Score queries against their keys and the queued negatives, then enqueue the keys.

    For every row of ``q`` one positive logit (dot product with the matching
    row of ``k``) and one negative logit per row of ``self.memory`` are
    computed and concatenated with the positive logit in column 0.

    Args:
        q: 2-D query tensor; the first dimension is the batch size.
        k: key tensor with the same number of elements per sample as ``q``.
            It is detached, so no gradient flows into it. It is written
            into ``self.memory`` row by row, so it presumably has shape
            ``(batchSize, feature_dim)`` -- confirm against the caller.

    Returns:
        The concatenated scores after ``squeeze().contiguous()``. When
        ``self.use_softmax`` is true they are logits divided by ``self.T``;
        otherwise they are ``exp(logit / T) / Z`` where ``Z`` is read from
        ``self.params[0]``.

    Side effects:
        * On the non-softmax path, if ``self.params[0]`` is negative it is
          overwritten once with ``out.mean() * self.outputSize`` and a
          message is printed.
        * ``self.memory`` rows ``(self.index + i) % self.queueSize`` are
          overwritten with ``k`` and ``self.index`` is advanced by the
          batch size modulo ``self.queueSize``.
    """
    batchSize = q.shape[0]
    k = k.detach()

    # A negative value means the normalisation constant is not set yet.
    Z = self.params[0].item()

    # pos logit: per-sample dot product between q and k -> (batchSize, 1)
    l_pos = torch.bmm(q.view(batchSize, 1, -1), k.view(batchSize, -1, 1))
    l_pos = l_pos.view(batchSize, 1)

    # neg logit: every queued entry against every query.
    # The queue is cloned before use; presumably so that the in-place
    # memory update below cannot disturb the tensor autograd keeps for the
    # backward pass of the matmul -- TODO confirm.
    queue = self.memory.clone()
    l_neg = torch.mm(queue.detach(), q.transpose(1, 0))
    l_neg = l_neg.transpose(0, 1)

    out = torch.cat((l_pos, l_neg), dim=1)

    # NOTE(review): squeeze() below also drops the batch dimension when
    # batchSize == 1; kept as-is so the returned shape stays unchanged for
    # existing callers.
    if self.use_softmax:
        out = torch.div(out, self.T)
        out = out.squeeze().contiguous()
    else:
        out = torch.exp(torch.div(out, self.T))
        if Z < 0:
            # Store the constant without recording the write in autograd.
            # Otherwise a floating-point `params` would become a non-leaf
            # tensor that keeps the first batch's graph alive.
            with torch.no_grad():
                self.params[0] = out.mean() * self.outputSize
            # Read the value back so any dtype conversion performed by the
            # stored tensor is reflected in Z.
            Z = self.params[0].item()
            print("normalization constant Z is set to {:.1f}".format(Z))
        # compute the out
        out = torch.div(out, Z).squeeze().contiguous()

    # update memory: overwrite the next batchSize slots, wrapping around
    with torch.no_grad():
        out_ids = torch.arange(batchSize, device=out.device)
        out_ids += self.index
        out_ids = torch.fmod(out_ids, self.queueSize)
        out_ids = out_ids.long()
        self.memory.index_copy_(0, out_ids, k)
        self.index = (self.index + batchSize) % self.queueSize

    return out
| 62 | |
| 63 | |
| 64 | class NCESoftmaxLoss(nn.Module): |
nothing calls this directly
no test coverage detected