| 3 | from .Model import Model |
| 4 | |
| 5 | class DistMult(Model): |
| 6 | |
def __init__(self, ent_tot, rel_tot, dim=100, margin=None, epsilon=None):
    """Build the entity and relation embedding tables for DistMult.

    Args:
        ent_tot: Number of entities; row count of ``self.ent_embeddings``.
        rel_tot: Number of relations; row count of ``self.rel_embeddings``.
        dim: Width of every entity and relation embedding vector.
        margin: Optional. When both ``margin`` and ``epsilon`` are given,
            both tables are initialised uniformly in
            ``[-(margin + epsilon) / dim, (margin + epsilon) / dim]``.
        epsilon: Optional. See ``margin``.

    If either ``margin`` or ``epsilon`` is ``None``, both tables are
    Xavier-uniform initialised instead. In that case no
    ``self.embedding_range`` attribute is created.
    """
    super(DistMult, self).__init__(ent_tot, rel_tot)

    self.dim = dim
    self.margin = margin
    self.epsilon = epsilon
    # self.ent_tot / self.rel_tot are read back here; presumably they are
    # stored by Model.__init__ (not visible in this file chunk).
    self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
    self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)

    # Identity test: PEP 8 comparison to the None singleton.
    if margin is None or epsilon is None:
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
    else:
        # Kept as a non-trainable nn.Parameter (requires_grad=False), so it is
        # registered on the module but never updated by the optimiser.
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
        )
        # Read the scalar bound once instead of calling .item() per argument.
        bound = self.embedding_range.item()
        nn.init.uniform_(tensor=self.ent_embeddings.weight.data, a=-bound, b=bound)
        nn.init.uniform_(tensor=self.rel_embeddings.weight.data, a=-bound, b=bound)
| 33 | |
def _calc(self, h, t, r, mode):
    """Return the DistMult score ``sum(h * r * t)`` over the last axis, flattened to 1-D.

    Args:
        h: Head embeddings.
        t: Tail embeddings.
        r: Relation embeddings; ``r.shape[0]`` is the relation batch size.
        mode: ``'normal'`` scores aligned rows. Any other value reshapes all
            three operands to ``(-1, r.shape[0], dim)`` so they broadcast.
            ``'head_batch'`` additionally changes the multiplication order.

    Returns:
        A 1-D tensor of scores.
    """
    if mode != 'normal':
        # Capture the relation batch size before r itself is reshaped.
        n_rel = r.shape[0]
        h, t, r = (emb.view(-1, n_rel, emb.shape[-1]) for emb in (h, t, r))

    # The product is the same either way. The order only decides which pair is
    # multiplied first; presumably the smaller pair is combined before it is
    # broadcast against the larger operand (h in head_batch, t otherwise).
    if mode == 'head_batch':
        trilinear = h * (r * t)
    else:
        trilinear = (h * r) * t
    return trilinear.sum(-1).flatten()
| 45 | |
def forward(self, data):
    """Look up the embeddings for a batch of triples and score them.

    Args:
        data: Mapping with the keys ``'batch_h'``, ``'batch_t'`` and
            ``'batch_r'`` (index inputs for the embedding tables) plus
            ``'mode'``, which is passed unchanged to ``self._calc``.

    Returns:
        Whatever ``self._calc`` returns for the looked-up embeddings.
    """
    heads_idx, tails_idx, rels_idx, mode = (
        data[key] for key in ('batch_h', 'batch_t', 'batch_r', 'mode')
    )
    # Heads and tails share the entity table; relations use their own table.
    lookup_entity = self.ent_embeddings
    return self._calc(
        lookup_entity(heads_idx),
        lookup_entity(tails_idx),
        self.rel_embeddings(rels_idx),
        mode,
    )
| 56 | |
| 57 | def regularization(self, data): |
| 58 | batch_h = data['batch_h'] |
| 59 | batch_t = data['batch_t'] |
| 60 | batch_r = data['batch_r'] |
| 61 | h = self.ent_embeddings(batch_h) |
| 62 | t = self.ent_embeddings(batch_t) |
no outgoing calls
no test coverage detected