MCPcopy
hub / github.com/thunlp/OpenKE / SimplE

Class SimplE

openke/module/model/SimplE.py:5–55  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

3from .Model import Model
4
5class SimplE(Model):
6
7 def __init__(self, ent_tot, rel_tot, dim = 100):
8 super(SimplE, self).__init__(ent_tot, rel_tot)
9
10 self.dim = dim
11 self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
12 self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)
13 self.rel_inv_embeddings = nn.Embedding(self.rel_tot, self.dim)
14
15 nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
16 nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
17 nn.init.xavier_uniform_(self.rel_inv_embeddings.weight.data)
18
19 def _calc_avg(self, h, t, r, r_inv):
20 return (torch.sum(h * r * t, -1) + torch.sum(h * r_inv * t, -1))/2
21
22 def _calc_ingr(self, h, r, t):
23 return torch.sum(h * r * t, -1)
24
25 def forward(self, data):
26 batch_h = data['batch_h']
27 batch_t = data['batch_t']
28 batch_r = data['batch_r']
29 h = self.ent_embeddings(batch_h)
30 t = self.ent_embeddings(batch_t)
31 r = self.rel_embeddings(batch_r)
32 r_inv = self.rel_inv_embeddings(batch_r)
33 score = self._calc_avg(h, t, r, r_inv)
34 return score
35
36 def regularization(self, data):
37 batch_h = data['batch_h']
38 batch_t = data['batch_t']
39 batch_r = data['batch_r']
40 h = self.ent_embeddings(batch_h)
41 t = self.ent_embeddings(batch_t)
42 r = self.rel_embeddings(batch_r)
43 r_inv = self.rel_inv_embeddings(batch_r)
44 regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2) + torch.mean(r_inv ** 2)) / 4
45 return regul
46
47 def predict(self, data):
48 batch_h = data['batch_h']
49 batch_t = data['batch_t']
50 batch_r = data['batch_r']
51 h = self.ent_embeddings(batch_h)
52 t = self.ent_embeddings(batch_t)
53 r = self.rel_embeddings(batch_r)
54 score = -self._calc_ingr(h, r, t)
55 return score.cpu().data.numpy()

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected