github.com/OpenNMT/OpenNMT-py

Method forward (class Elementwise)

onmt/modules/util_class.py, lines 23–36
(self, emb)


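        super(Elementwise, self).__init__(*args)

    def forward(self, emb):
        # Split the 3-D input along dim 2 into one 2-D tensor per feature.
        emb_ = [feat.squeeze(2) for feat in emb.split(1, dim=2)]
        emb_out = []
        # Apply the i-th child module to the i-th feature (self iterates over its modules).
        # for some reason list comprehension is slower in this scenario
        for f, x in zip(self, emb_):
            emb_out.append(f(x))
        # Combine the per-feature outputs according to self.merge.
        if self.merge == "first":
            # keep only the first feature's output
            return emb_out[0]
        elif self.merge == "concat" or self.merge == "mlp":
            # concatenate the outputs along dim 2
            return torch.cat(emb_out, 2)
        elif self.merge == "sum":
            # elementwise sum; the outputs must have compatible shapes
            return sum(emb_out)
        else:
            # any other merge value (e.g. None): return the list of outputs
            return emb_out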
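You can trace the split, apply, merge flow without PyTorch. The sketch below is a hypothetical, dependency-free re-creation that works on nested Python lists shaped [d0][d1][n_feat]. In it, make_lookup stands in for an embedding table and elementwise_forward mirrors the method's control flow. Neither name exists in OpenNMT-py, and the sketch does not model tensor semantics (dtypes, broadcasting, autograd).

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical types: a grid of feature indices and a grid of vectors.
Grid = List[List[int]]
VecGrid = List[List[List[float]]]


def make_lookup(table: Sequence[Sequence[float]]) -> Callable[[Grid], VecGrid]:
    """Hypothetical stand-in for an embedding table: maps each index to a vector."""
    def lookup(ids: Grid) -> VecGrid:
        return [[list(table[i]) for i in row] for row in ids]
    return lookup


def elementwise_forward(modules, emb, merge: Optional[str] = None):
    """Hypothetical mirror of Elementwise.forward on nested lists [d0][d1][n_feat]."""
    n_feat = len(emb[0][0])
    # Like emb.split(1, dim=2) followed by squeeze(2): one [d0][d1] grid per feature.
    feats = [[[cell[k] for cell in row] for row in emb] for k in range(n_feat)]
    # Apply the k-th module to the k-th feature grid (zip stops at the shorter input).
    outs = [f(x) for f, x in zip(modules, feats)]
    if merge == "first":
        return outs[0]
    if merge in ("concat", "mlp"):
        # Like torch.cat(outs, 2): join the per-feature vectors at each position.
        return [[sum((o[i][j] for o in outs), []) for j in range(len(outs[0][i]))]
                for i in range(len(outs[0]))]
    if merge == "sum":
        # Elementwise sum of the per-feature vectors at each position.
        return [[[sum(vals) for vals in zip(*(o[i][j] for o in outs))]
                 for j in range(len(outs[0][i]))]
                for i in range(len(outs[0]))]
    # Any other merge value (e.g. None): the list of per-feature outputs.
    return outs


# Toy input: d0=1, d1=2, two features per position (a word id and a tag id).
words = make_lookup([[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]])
tags = make_lookup([[10.0, 20.0], [30.0, 40.0]])
emb = [[[1, 0], [2, 1]]]

print(elementwise_forward([words, tags], emb, "concat"))
# [[[1.0, 2.0, 10.0, 20.0], [3.0, 4.0, 30.0, 40.0]]]
print(elementwise_forward([words, tags], emb, "sum"))
# [[[11.0, 22.0], [33.0, 44.0]]]
```

In the toy call, "concat" produces a 4-dimensional vector per position, while "sum" keeps dimension 2. This is why summing generally expects every module to produce outputs of the same size.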

Callers

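nothing calls this directly (forward is normally invoked by calling the module instance, e.g. elementwise(emb), which goes through torch.nn.Module.__call__)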

Calls (1)

squeeze (method) · 0.80

Tested by

no test coverage detected