MCPcopy
hub / github.com/thunlp/OpenKE / ComplEx

Class ComplEx

openke/module/model/ComplEx.py:5–62  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

3from .Model import Model
4
class ComplEx(Model):
    """ComplEx knowledge-graph embedding model.

    Every entity and relation is a complex vector of size ``dim``, stored as
    two real-valued ``nn.Embedding`` tables: one for the real part and one
    for the imaginary part. A triple (h, r, t) is scored by the real part of
    ``sum_k h_k * r_k * conj(t_k)`` (see ``_calc``).

    Args:
        ent_tot: number of entities (rows of the two entity tables).
        rel_tot: number of relations (rows of the two relation tables).
        dim: size of each complex embedding (columns of every table).
    """

    def __init__(self, ent_tot: int, rel_tot: int, dim: int = 100):
        # ``self.ent_tot`` / ``self.rel_tot`` are read below; they are
        # assumed to be set by ``Model.__init__`` (not visible here).
        super().__init__(ent_tot, rel_tot)

        self.dim = dim
        # Attribute names double as ``state_dict`` keys, so keep them stable
        # for checkpoint compatibility.
        self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_re_embeddings = nn.Embedding(self.rel_tot, self.dim)
        self.rel_im_embeddings = nn.Embedding(self.rel_tot, self.dim)

        # Pass the Parameter itself rather than ``.weight.data``: the
        # ``nn.init`` functions already work under ``torch.no_grad()`` and
        # this is their documented calling convention. Tables are
        # initialised in the same order as before, so seeded runs draw the
        # same random values.
        for table in (
            self.ent_re_embeddings,
            self.ent_im_embeddings,
            self.rel_re_embeddings,
            self.rel_im_embeddings,
        ):
            nn.init.xavier_uniform_(table.weight)

    def _embed_batch(self, data):
        """Look up the real/imaginary embeddings for a batch of triples.

        Args:
            data: mapping holding entity index tensors under ``'batch_h'``
                and ``'batch_t'`` and relation index tensors under
                ``'batch_r'``.

        Returns:
            Tuple ``(h_re, h_im, t_re, t_im, r_re, r_im)`` in exactly the
            argument order expected by ``_calc``.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        return (
            self.ent_re_embeddings(batch_h),
            self.ent_im_embeddings(batch_h),
            self.ent_re_embeddings(batch_t),
            self.ent_im_embeddings(batch_t),
            self.rel_re_embeddings(batch_r),
            self.rel_im_embeddings(batch_r),
        )

    def _calc(self, h_re, h_im, t_re, t_im, r_re, r_im):
        """Return Re(h * r * conj(t)) summed over the last (embedding) axis.

        Expanding (h_re + i*h_im)(r_re + i*r_im)(t_re - i*t_im) and keeping
        only the real part yields the four terms below.
        """
        return torch.sum(
            h_re * t_re * r_re
            + h_im * t_im * r_re
            + h_re * t_im * r_im
            - h_im * t_re * r_im,
            -1,
        )

    def forward(self, data) -> torch.Tensor:
        """Score a batch of triples.

        Args:
            data: mapping with index tensors under ``'batch_h'``,
                ``'batch_t'`` (entity ids) and ``'batch_r'`` (relation ids).

        Returns:
            One real score per triple: the embedding axis is summed out, so
            the result has the shape of the index tensors (assuming the
            three index tensors share a shape — TODO confirm with callers).
        """
        return self._calc(*self._embed_batch(data))

    def regularization(self, data) -> torch.Tensor:
        """Return the L2-style penalty on the embeddings used by this batch.

        Computes the mean of the squared entries of each of the six looked-up
        embedding tensors (real/imaginary parts of head, tail and relation)
        and averages those six means. How it is weighted into the loss is
        decided by the caller (not visible here).
        """
        embeddings = self._embed_batch(data)
        return sum(torch.mean(e ** 2) for e in embeddings) / len(embeddings)

    def predict(self, data):
        """Return the negated ``forward`` scores as a NumPy array on the CPU.

        The sign flip presumably lets downstream ranking treat lower values
        as better (NOTE(review): confirm against the evaluation code).
        """
        score = -self.forward(data)
        # ``detach()`` is the supported replacement for the legacy ``.data``.
        return score.detach().cpu().numpy()

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected