(self, emb)
| 21 | super(Elementwise, self).__init__(*args) |
| 22 | |
def forward(self, emb):
    """Apply each wrapped module to its own feature slice of ``emb`` and merge.

    ``emb`` is split along dim 2 into one slice per feature, and dim 2 is
    squeezed away from each slice. The i-th module yielded by iterating
    ``self`` is then applied to the i-th slice.

    Args:
        emb: tensor whose dim 2 indexes features. Presumably a
            (seq_len, batch, n_feats) tensor of indices -- TODO confirm
            against callers.

    Returns:
        Depending on ``self.merge``:
          * ``"first"``: only the first module's output.
          * ``"concat"`` or ``"mlp"``: all outputs concatenated along dim 2.
          * ``"sum"``: the element-wise sum of all outputs.
          * any other value: the plain list of per-feature outputs.

    Raises:
        ValueError: if the number of feature slices in ``emb`` differs from
            the number of wrapped modules.
    """
    emb_ = [feat.squeeze(2) for feat in emb.split(1, dim=2)]
    modules = list(self)
    # zip() would silently truncate to the shorter sequence on a mismatch,
    # dropping features (or modules) and producing a wrong merged result.
    if len(modules) != len(emb_):
        raise ValueError(
            f"{type(self).__name__}: expected {len(modules)} feature(s) "
            f"along dim 2 of emb, got {len(emb_)}"
        )
    emb_out = []
    # The original author observed a list comprehension to be slower than
    # this explicit loop here, so the loop is kept deliberately.
    for f, x in zip(modules, emb_):
        emb_out.append(f(x))
    if self.merge == "first":
        return emb_out[0]
    elif self.merge == "concat" or self.merge == "mlp":
        return torch.cat(emb_out, 2)
    elif self.merge == "sum":
        return sum(emb_out)
    else:
        return emb_out