| 497 | self.to(device) |
| 498 | |
def forward(self, inputs):
    """Apply the stacked mixture-of-low-rank-experts cross layers.

    For each of ``self.layer_num`` layers, every expert:

    1. projects the running state ``x_l`` into the low-rank space with
       ``V^T``,
    2. applies ``tanh``, multiplies by ``C`` and applies ``tanh`` again,
    3. projects back with ``U``, adds the layer bias and takes the
       Hadamard product with the original input ``x_0``.

    The expert outputs are combined with softmax-normalised gating scores
    and added residually to ``x_l``.

    Args:
        inputs: tensor of shape (bs, in_features), per the shape comments
            in this method.

    Returns:
        Tensor of shape (bs, in_features). The batch axis is kept even
        when ``bs == 1``.
    """
    x_0 = inputs.unsqueeze(2)  # (bs, in_features, 1)
    x_l = x_0
    for i in range(self.layer_num):
        # Both values are the same for every expert of this layer, so
        # compute them once instead of once per expert.
        gate_input = x_l.squeeze(2)  # (bs, in_features)
        layer_bias = self.bias[i]

        output_of_experts = []
        gating_score_of_experts = []
        for expert_id in range(self.num_experts):
            # (1) G(x_l): gating score of this expert, computed from x_l.
            # Presumably (bs, 1) per expert, given the stacked shape noted
            # below -- confirm against the gating module definition.
            gating_score_of_experts.append(self.gating[expert_id](gate_input))

            # (2) E(x_l): project x_l down to the low-rank space R^r.
            v_x = torch.matmul(self.V_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)

            # Nonlinear activation inside the low-rank space.
            v_x = torch.tanh(v_x)
            v_x = torch.matmul(self.C_list[i][expert_id], v_x)
            v_x = torch.tanh(v_x)

            # Project back to R^d.
            uv_x = torch.matmul(self.U_list[i][expert_id], v_x)  # (bs, in_features, 1)

            # Hadamard product with the original input.
            dot_ = x_0 * (uv_x + layer_bias)

            output_of_experts.append(dot_.squeeze(2))

        # (3) Mixture of low-rank experts.
        output_of_experts = torch.stack(output_of_experts, 2)  # (bs, in_features, num_experts)
        gating_score_of_experts = torch.stack(gating_score_of_experts, 1)  # (bs, num_experts, 1)
        moe_out = torch.matmul(output_of_experts, gating_score_of_experts.softmax(1))
        x_l = moe_out + x_l  # (bs, in_features, 1)

    # Drop only the trailing singleton axis. A bare ``squeeze()`` would also
    # remove the batch axis when bs == 1 (and the feature axis when
    # in_features == 1), which breaks the (bs, in_features) contract.
    x_l = x_l.squeeze(2)  # (bs, in_features)
    return x_l
| 535 | |
| 536 | |
| 537 | class InnerProductLayer(nn.Module): |