| 3 | from .Model import Model |
| 4 | |
class ComplEx(Model):
    """ComplEx bilinear knowledge-graph embedding model.

    Every entity and relation is a complex vector with ``dim`` components.
    Each is stored as two real-valued ``nn.Embedding`` tables: one for the
    real part and one for the imaginary part. A triple (h, r, t) is scored
    as ``Re(sum_k h_k * r_k * conj(t_k))``.

    Args:
        ent_tot: number of entities (rows in each entity table).
        rel_tot: number of relations (rows in each relation table).
        dim: number of complex components per embedding.
    """

    def __init__(self, ent_tot, rel_tot, dim=100):
        super().__init__(ent_tot, rel_tot)

        self.dim = dim
        # NOTE(review): self.ent_tot / self.rel_tot are read back here, so the
        # Model base class is assumed to store them in __init__ — confirm.
        self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_re_embeddings = nn.Embedding(self.rel_tot, self.dim)
        self.rel_im_embeddings = nn.Embedding(self.rel_tot, self.dim)

        # Re-initialise every table with Xavier-uniform. The construction and
        # init order above/below is kept fixed so seeded runs draw the same
        # random numbers as before.
        for table in (
            self.ent_re_embeddings,
            self.ent_im_embeddings,
            self.rel_re_embeddings,
            self.rel_im_embeddings,
        ):
            nn.init.xavier_uniform_(table.weight.data)

    def _calc(self, h_re, h_im, t_re, t_im, r_re, r_im):
        """Return Re(<h, r, conj(t)>), summed over the last axis.

        Expanding (h_re + i h_im)(r_re + i r_im)(t_re - i t_im) and keeping
        only the real part gives the four terms below.
        """
        return torch.sum(
            h_re * t_re * r_re
            + h_im * t_im * r_re
            + h_re * t_im * r_im
            - h_im * t_re * r_im,
            -1
        )

    def _lookup_embeddings(self, data):
        """Look up all embedding parts for a batch.

        Args:
            data: mapping with index tensors under ``'batch_h'``,
                ``'batch_t'`` and ``'batch_r'``.

        Returns:
            Tuple ``(h_re, h_im, t_re, t_im, r_re, r_im)``. The order matches
            the positional parameters of ``_calc``.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        return (
            self.ent_re_embeddings(batch_h),
            self.ent_im_embeddings(batch_h),
            self.ent_re_embeddings(batch_t),
            self.ent_im_embeddings(batch_t),
            self.rel_re_embeddings(batch_r),
            self.rel_im_embeddings(batch_r),
        )

    def forward(self, data):
        """Score each (head, relation, tail) triple in the batch.

        Args:
            data: mapping with index tensors under ``'batch_h'``,
                ``'batch_t'`` and ``'batch_r'``.

        Returns:
            Tensor of ComplEx scores. Its shape is the index tensors' shape,
            because ``_calc`` sums away the embedding axis.
        """
        h_re, h_im, t_re, t_im, r_re, r_im = self._lookup_embeddings(data)
        return self._calc(h_re, h_im, t_re, t_im, r_re, r_im)

    def regularization(self, data):
        """Return the average of the mean squared entries of the six batch embeddings.

        Args:
            data: same mapping as ``forward``.

        Returns:
            Scalar tensor:
            ``(mean(h_re**2) + mean(h_im**2) + mean(t_re**2) + mean(t_im**2)
            + mean(r_re**2) + mean(r_im**2)) / 6``.
        """
        parts = self._lookup_embeddings(data)
        # sum() adds the terms in the same left-to-right order as the
        # explicit expression in the docstring.
        return sum(torch.mean(part ** 2) for part in parts) / len(parts)

    def predict(self, data):
        """Return the negated ``forward`` scores as a NumPy array on the CPU.

        The sign flip presumably makes lower values rank better for the
        evaluator. NOTE(review): confirm against the caller.
        """
        score = -self.forward(data)
        # detach() is the supported way to leave autograd (the old code used
        # the legacy ``.data`` attribute); the values are identical.
        return score.detach().cpu().numpy()
no outgoing calls
no test coverage detected