| 3 | from .Model import Model |
| 4 | |
class SimplE(Model):
    """SimplE-style bilinear knowledge-graph embedding model.

    Each relation has a forward embedding (``rel_embeddings``) and an inverse
    embedding (``rel_inv_embeddings``). Entities share a single table
    (``ent_embeddings``) for both the head and the tail role.

    NOTE(review): the SimplE paper gives every entity separate head and tail
    vectors. Because this implementation shares one entity table, the
    training score ``(sum(h*r*t) + sum(h*r_inv*t)) / 2`` equals
    ``sum(h * ((r + r_inv) / 2) * t)``, i.e. DistMult with the relation vector
    ``(r + r_inv) / 2``. The parameter layout is kept unchanged so that saved
    checkpoints (parameter names and shapes) remain loadable.
    """

    def __init__(self, ent_tot, rel_tot, dim = 100):
        """Create the three embedding tables and Xavier-initialise them.

        Args:
            ent_tot: number of entities (rows of ``ent_embeddings``).
            rel_tot: number of relations (rows of both relation tables).
            dim: embedding width shared by all three tables.
        """
        super(SimplE, self).__init__(ent_tot, rel_tot)

        self.dim = dim
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
        self.rel_inv_embeddings = nn.Embedding(self.rel_tot, self.dim)

        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_inv_embeddings.weight.data)

    def _embed_batch(self, data):
        """Look up ``(h, t, r, r_inv)`` embeddings for one batch.

        ``data`` must map 'batch_h', 'batch_t' and 'batch_r' to index
        tensors accepted by ``nn.Embedding``. The inverse-relation embedding
        is indexed with the same 'batch_r' ids as the forward one.
        """
        h = self.ent_embeddings(data['batch_h'])
        t = self.ent_embeddings(data['batch_t'])
        r = self.rel_embeddings(data['batch_r'])
        r_inv = self.rel_inv_embeddings(data['batch_r'])
        return h, t, r, r_inv

    def _calc_avg(self, h, t, r, r_inv):
        """Average of the forward and inverse trilinear products (last axis)."""
        return (torch.sum(h * r * t, -1) + torch.sum(h * r_inv * t, -1))/2

    def _calc_ingr(self, h, r, t):
        """Trilinear product using only the forward relation (inverse ignored)."""
        return torch.sum(h * r * t, -1)

    def forward(self, data):
        """Score a batch of triples, one value per triple.

        The score is the averaged forward/inverse trilinear product.
        ``predict`` returns the negation of this same score by default.
        """
        h, t, r, r_inv = self._embed_batch(data)
        return self._calc_avg(h, t, r, r_inv)

    def regularization(self, data):
        """Return the mean squared magnitude of the batch's embeddings.

        Each of h, t, r and r_inv contributes its element-wise mean of
        squares. The four terms are averaged with equal weight.
        """
        h, t, r, r_inv = self._embed_batch(data)
        regul = (torch.mean(h ** 2) + torch.mean(t ** 2)
                 + torch.mean(r ** 2) + torch.mean(r_inv ** 2)) / 4
        return regul

    def predict(self, data, ignore_inverse=False):
        """Return negated triple scores as a NumPy array (lower = better).

        By default the same averaged score used in ``forward`` is applied, so
        evaluation ranks triples with the function the model was trained on.
        Previously only the forward relation vector was used. Because ``r`` and
        ``r_inv`` receive identical gradients from that score, their
        difference is essentially leftover initialisation noise, and ranking
        with ``r`` alone scored a different function.

        Args:
            data: batch dict with 'batch_h', 'batch_t', 'batch_r' indices.
            ignore_inverse: if True, reproduce the old behaviour and score
                with the forward relation only.
        """
        h, t, r, r_inv = self._embed_batch(data)
        if ignore_inverse:
            score = -self._calc_ingr(h, r, t)
        else:
            score = -self._calc_avg(h, t, r, r_inv)
        # detach() is the supported replacement for the legacy ``.data``.
        return score.detach().cpu().numpy()
no outgoing calls
no test coverage detected