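One cross layer of `CrossNetMix` can be sketched for a single example without PyTorch. It follows the mixture-of-low-rank-experts formulation from the DCN-M paper: $x_{l+1} = \sum_i G_i(x_l) E_i(x_l) + x_l$ with $E_i(x_l) = x_0 \odot (U_i\, g(C_i\, g(V_i^\top x_l)) + b_l)$. This is a minimal, dependency-free sketch, not the library's code. The listing below is cut off right after the $V_i^\top x_l$ projection, so a few choices here are assumptions: $g = \tanh$, softmax over the experts' gating logits, and the hypothetical helper names `matvec`, `transpose`, `softmax`, and `cross_mix_layer`. The shapes mirror the class: per-expert `U` and `V` of shape `(in_features, low_rank)`, `C` of shape `(low_rank, low_rank)`, a bias-free linear gate per expert, and one bias per layer shared by all experts.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # row-major: matrix[row][col]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax(logits: Vector) -> Vector:
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_mix_layer(x0: Vector, xl: Vector,
                    U: List[Matrix], V: List[Matrix], C: List[Matrix],
                    gate_w: List[Vector], bias: Vector) -> Vector:
    """Hypothetical single-example DCN-Mix cross layer (sketch only).

    U[i], V[i]: (d x r) per expert; C[i]: (r x r) per expert;
    gate_w[i]: length-d gating weights (no bias); bias: length-d, shared by experts.
    """
    logits: Vector = []
    experts: List[Vector] = []
    for Ui, Vi, Ci, wi in zip(U, V, C, gate_w):
        # (1) Gating logit G_i(x_l): a bias-free linear map of x_l.
        logits.append(sum(w * x for w, x in zip(wi, xl)))
        # (2) Expert E_i(x_l): project into the r-dim subspace, apply the
        # (assumed) tanh nonlinearity there, transform with C_i, apply tanh again.
        z = [math.tanh(t) for t in matvec(transpose(Vi), xl)]  # g(V_i^T x_l)
        z = [math.tanh(t) for t in matvec(Ci, z)]              # g(C_i z)
        uz = matvec(Ui, z)                                     # back to d dims
        # Add the layer bias, then take the Hadamard product with x_0.
        experts.append([a * (u + b) for a, u, b in zip(x0, uz, bias)])
    # (3) Mix the experts with (assumed) softmax-normalized gating weights.
    weights = softmax(logits)
    mixed = [sum(g * e[k] for g, e in zip(weights, experts)) for k in range(len(xl))]
    # Residual connection: x_{l+1} = sum_i G_i(x_l) E_i(x_l) + x_l
    return [m + x for m, x in zip(mixed, xl)]
```

Each expert works through an r-dimensional projection, so its cost is O(d·r) instead of the O(d²) of a full-rank cross layer, and the gating weights decide how much each subspace contributes to the next layer.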


class CrossNetMix(nn.Module):
    """The Cross Network part of the DCN-Mix model, which improves on DCN-M by:
      1. adding a mixture of experts (MoE) to learn feature interactions in different subspaces;
      2. adding nonlinear transformations in the low-dimensional space.
      Input shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Output shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Arguments
        - **in_features** : Positive integer, dimensionality of input features.
        - **low_rank** : Positive integer, dimensionality of the low-rank space.
        - **num_experts** : Positive integer, number of experts.
        - **layer_num** : Positive integer, number of cross layers.
        - **device** : str, e.g. ``"cpu"`` or ``"cuda:0"``.
      References
        - [Wang R, Shivanna R, Cheng D Z, et al. DCN-M: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems[J]. 2020.](https://arxiv.org/abs/2008.13535)
    """

    def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device='cpu'):
        super(CrossNetMix, self).__init__()
        self.layer_num = layer_num
        self.num_experts = num_experts

        # U: (layer_num, num_experts, in_features, low_rank); maps low-rank vectors back to in_features
        self.U_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
        # V: (layer_num, num_experts, in_features, low_rank); V^T projects x_l into the low-rank space
        self.V_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
        # C: (layer_num, num_experts, low_rank, low_rank); transforms inside the low-rank space
        self.C_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, low_rank, low_rank))
        # One bias-free gating unit per expert, shared across all cross layers
        self.gating = nn.ModuleList([nn.Linear(in_features, 1, bias=False) for _ in range(self.num_experts)])

        # One bias vector per cross layer, shared by all experts of that layer
        self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))

        # Xavier-normal initialization for the U, V and C weights of every layer
        init_para_list = [self.U_list, self.V_list, self.C_list]
        for para in init_para_list:
            for i in range(self.layer_num):
                nn.init.xavier_normal_(para[i])

        # Biases start at zero
        for i in range(len(self.bias)):
            nn.init.zeros_(self.bias[i])

        self.to(device)

    def forward(self, inputs):
        x_0 = inputs.unsqueeze(2)  # (bs, in_features, 1)
        x_l = x_0
        for i in range(self.layer_num):
            output_of_experts = []
            gating_score_of_experts = []
            for expert_id in range(self.num_experts):
                # (1) G(x_l)
                # compute this expert's gating score from x_l
                gating_score_of_experts.append(self.gating[expert_id](x_l.squeeze(2)))  # (bs, 1)

                # (2) E(x_l)
                # project the input x_l to $\mathbb{R}^{r}$
                v_x = torch.matmul(self.V_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)
