MCPcopy
hub / github.com/thunlp/OpenKE / DistMult

Class DistMult

openke/module/model/DistMult.py:5–72  ·  view source on GitHub ↗

Source loaded from the content-addressed store (hash-verified).

3from .Model import Model
4
5class DistMult(Model):
6
def __init__(self, ent_tot, rel_tot, dim=100, margin=None, epsilon=None):
    """Create the DistMult entity/relation embedding tables and initialise them.

    Args:
        ent_tot: number of entities; forwarded to the base ``Model`` and used
            as the row count of ``ent_embeddings``.
        rel_tot: number of relations; forwarded to the base ``Model`` and used
            as the row count of ``rel_embeddings``.
        dim: embedding width shared by entities and relations.
        margin, epsilon: when BOTH are provided, every embedding weight is
            drawn uniformly from ``[-(margin + epsilon) / dim,
            (margin + epsilon) / dim]`` and that bound is kept as the
            non-trainable parameter ``self.embedding_range``. If either is
            ``None``, Xavier-uniform initialisation is used instead and
            ``embedding_range`` is not created.
    """
    super(DistMult, self).__init__(ent_tot, rel_tot)

    self.dim = dim
    self.margin = margin
    self.epsilon = epsilon
    # NOTE(review): self.ent_tot / self.rel_tot are presumably set by
    # Model.__init__ above — confirm against the base class.
    self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
    self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)

    # Identity test, not `== None`: it is the idiomatic singleton check and
    # cannot be hijacked by an overloaded __eq__ (e.g. array-like margins).
    if margin is None or epsilon is None:
        nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
    else:
        # Stored as a frozen Parameter (requires_grad=False) rather than a
        # plain float.
        self.embedding_range = nn.Parameter(
            torch.Tensor([(self.margin + self.epsilon) / self.dim]),
            requires_grad=False,
        )
        bound = self.embedding_range.item()
        nn.init.uniform_(tensor=self.ent_embeddings.weight.data, a=-bound, b=bound)
        nn.init.uniform_(tensor=self.rel_embeddings.weight.data, a=-bound, b=bound)
33
def _calc(self, h, t, r, mode):
    """Compute the DistMult triple score: the sum over the last axis of h*r*t.

    Args:
        h, t, r: head, tail and relation embeddings.
        mode: ``'normal'`` uses the inputs as given. Any other value reshapes
            all three to ``(-1, r.shape[0], last_dim)`` before multiplying.
            ``'head_batch'`` groups the product as ``h * (r * t)``; every
            other mode groups it as ``(h * r) * t``.

    Returns:
        A 1-D tensor of scores (the summed product, flattened).
    """
    if mode != 'normal':
        # Capture the relation batch size before r itself is reshaped.
        n_rel = r.shape[0]
        h = h.view(-1, n_rel, h.shape[-1])
        t = t.view(-1, n_rel, t.shape[-1])
        r = r.view(-1, n_rel, r.shape[-1])

    # The grouping is mode-dependent on purpose (presumably so the factor
    # that pairs with r is formed before broadcasting against the side that
    # carries the extra leading dimension — confirm). Float rounding depends
    # on grouping, so it is kept exactly.
    elementwise = h * (r * t) if mode == 'head_batch' else (h * r) * t
    return elementwise.sum(-1).flatten()
45
def forward(self, data):
    """Embed a batch of (head, relation, tail) indices and score it.

    Args:
        data: mapping read at keys ``'batch_h'``, ``'batch_t'`` and
            ``'batch_r'`` (indices looked up in the entity / relation
            embedding tables) and ``'mode'`` (forwarded unchanged to
            ``_calc``).

    Returns:
        Whatever ``self._calc`` returns for the embedded batch.
    """
    head_idx, tail_idx, rel_idx, mode = (
        data['batch_h'],
        data['batch_t'],
        data['batch_r'],
        data['mode'],
    )
    # Arguments are evaluated left to right: head, tail, then relation lookup.
    return self._calc(
        self.ent_embeddings(head_idx),
        self.ent_embeddings(tail_idx),
        self.rel_embeddings(rel_idx),
        mode,
    )
56
57 def regularization(self, data):
58 batch_h = data['batch_h']
59 batch_t = data['batch_t']
60 batch_r = data['batch_r']
61 h = self.ent_embeddings(batch_h)
62 t = self.ent_embeddings(batch_t)

Calls

no outgoing calls

Tested by

no test coverage detected