Class CrossNetMix

Defined in deepctr_torch/layers/interaction.py, lines 456–534 (repository github.com/shenweichen/DeepCTR-Torch).

The Cross Network part of the DCN-Mix model, which improves on DCN-M in two ways: (1) it adds a mixture of experts (MoE) to learn feature interactions in different subspaces, and (2) it adds nonlinear transformations in the low-dimensional space. Input shape: 2D tensor of shape ``(batch_size, units)``. Output shape: 2D tensor of shape ``(batch_size, units)``.
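As we read the cited DCN-M paper, each cross layer $l$ mixes $K$ low-rank experts $E_i$ through a gating function $G_i$ and adds a residual connection; $g$ is a nonlinearity that the paper leaves as a design choice. The mapping to the listing below ($U_l^i$, $V_l^i$, $C_l^i$ to `U_list[l][i]`, `V_list[l][i]`, `C_list[l][i]`; $b_l$ to `bias[l]`; $G_i$ to `gating[i]`; $d$ to `in_features`; $r$ to `low_rank`) is our interpretation, not something the source states.

```latex
x_{l+1} = \sum_{i=1}^{K} G_i(x_l)\, E_i(x_l) + x_l,
\qquad
E_i(x_l) = x_0 \odot \Big( U_l^i \, g\big( C_l^i \, g\big( {V_l^i}^{\top} x_l \big) \big) + b_l \Big),
\qquad
U_l^i, V_l^i \in \mathbb{R}^{d \times r},\; C_l^i \in \mathbb{R}^{r \times r},\; b_l \in \mathbb{R}^{d}
```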


```python
class CrossNetMix(nn.Module):
    """The Cross Network part of DCN-Mix model, which improves DCN-M by:
      1. adding MoE to learn feature interactions in different subspaces
      2. adding nonlinear transformations in low-dimensional space
      Input shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Output shape
        - 2D tensor with shape: ``(batch_size, units)``.
      Arguments
        - **in_features** : Positive integer, dimensionality of input features.
        - **low_rank** : Positive integer, dimensionality of low-rank space.
        - **num_experts** : Positive integer, number of experts.
        - **layer_num** : Positive integer, the number of cross layers.
        - **device** : str, e.g. ``"cpu"`` or ``"cuda:0"``
      References
        - [Wang R, Shivanna R, Cheng D Z, et al. DCN-M: Improved Deep & Cross Network for Feature Cross Learning in Web-scale Learning to Rank Systems[J]. 2020.](https://arxiv.org/abs/2008.13535)
    """

    def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2, device='cpu'):
        super(CrossNetMix, self).__init__()
        self.layer_num = layer_num
        self.num_experts = num_experts

        # U: one (in_features, low_rank) matrix per layer and expert
        self.U_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
        # V: one (in_features, low_rank) matrix per layer and expert
        self.V_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, in_features, low_rank))
        # C: one (low_rank, low_rank) matrix per layer and expert
        self.C_list = nn.Parameter(torch.Tensor(self.layer_num, num_experts, low_rank, low_rank))
        # one bias-free linear gate per expert, shared across layers
        self.gating = nn.ModuleList([nn.Linear(in_features, 1, bias=False) for i in range(self.num_experts)])

        self.bias = nn.Parameter(torch.Tensor(self.layer_num, in_features, 1))

        # Xavier-normal init for U, V, C; zero init for the per-layer bias
        init_para_list = [self.U_list, self.V_list, self.C_list]
        for para in init_para_list:
            for i in range(self.layer_num):
                nn.init.xavier_normal_(para[i])

        for i in range(len(self.bias)):
            nn.init.zeros_(self.bias[i])

        self.to(device)

    def forward(self, inputs):
        x_0 = inputs.unsqueeze(2)  # (bs, in_features, 1)
        x_l = x_0
        for i in range(self.layer_num):
            output_of_experts = []
            gating_score_of_experts = []
            for expert_id in range(self.num_experts):
                # (1) G(x_l)
                # compute the gating score by x_l, shape (bs, 1)
                gating_score_of_experts.append(self.gating[expert_id](x_l.squeeze(2)))

                # (2) E(x_l)
                # project the input x_l to $\mathbb{R}^{r}$
                v_x = torch.matmul(self.V_list[i][expert_id].t(), x_l)  # (bs, low_rank, 1)

                # ... listing ends here; forward() continues to line 534 of interaction.py
```
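Because the listing stops partway through `forward()`, the following is a minimal, self-contained sketch of one complete DCN-Mix cross network, written from the paper's update rule rather than from the library's remaining lines. The class name `CrossNetMixSketch` is hypothetical; `tanh` as the nonlinearity $g$ and a softmax over experts as the gate are assumptions of this sketch. Parameter names and shapes mirror the listing so the two can be compared.

```python
import torch
import torch.nn as nn


class CrossNetMixSketch(nn.Module):
    """Hypothetical DCN-Mix cross network (sketch, not the library code).

    Assumptions: g = tanh, and gate logits are softmax-normalized over experts.
    """

    def __init__(self, in_features, low_rank=32, num_experts=4, layer_num=2):
        super().__init__()
        self.layer_num = layer_num
        self.num_experts = num_experts
        self.U_list = nn.Parameter(torch.empty(layer_num, num_experts, in_features, low_rank))
        self.V_list = nn.Parameter(torch.empty(layer_num, num_experts, in_features, low_rank))
        self.C_list = nn.Parameter(torch.empty(layer_num, num_experts, low_rank, low_rank))
        self.gating = nn.ModuleList(
            [nn.Linear(in_features, 1, bias=False) for _ in range(num_experts)])
        self.bias = nn.Parameter(torch.zeros(layer_num, in_features, 1))
        for para in (self.U_list, self.V_list, self.C_list):
            for i in range(layer_num):
                nn.init.xavier_normal_(para[i])

    def forward(self, inputs):  # inputs: (bs, d)
        x_0 = inputs.unsqueeze(2)  # (bs, d, 1)
        x_l = x_0
        for i in range(self.layer_num):
            expert_outputs, gate_logits = [], []
            for e in range(self.num_experts):
                # G_e(x_l): one gate logit per sample, (bs, 1)
                gate_logits.append(self.gating[e](x_l.squeeze(2)))
                # E_e(x_l): project to R^r, nonlinearity, mix with C, nonlinearity
                v_x = torch.tanh(torch.matmul(self.V_list[i, e].t(), x_l))  # (bs, r, 1)
                v_x = torch.tanh(torch.matmul(self.C_list[i, e], v_x))      # (bs, r, 1)
                uv_x = torch.matmul(self.U_list[i, e], v_x)                 # (bs, d, 1)
                # Hadamard product with x_0 after adding the layer bias
                expert_outputs.append((x_0 * (uv_x + self.bias[i])).squeeze(2))
            experts = torch.stack(expert_outputs, dim=2)            # (bs, d, K)
            gates = torch.stack(gate_logits, dim=1).softmax(dim=1)  # (bs, K, 1)
            # gate-weighted sum of experts plus the residual connection
            x_l = torch.matmul(experts, gates) + x_l                # (bs, d, 1)
        return x_l.squeeze(2)  # (bs, d)


torch.manual_seed(0)
layer = CrossNetMixSketch(in_features=8, low_rank=3, num_experts=2, layer_num=2)
out = layer(torch.randn(5, 8))
print(out.shape)
```

With `in_features=8`, `low_rank=3`, `num_experts=2` and `layer_num=2`, the sketch holds 96 entries each in `U_list` and `V_list`, 36 in `C_list`, 16 gating weights and 16 bias terms (260 in total), which shows how `low_rank` controls the size of each expert. The final `squeeze(2)` is a deliberate choice in the sketch: squeezing only the trailing dimension keeps the batch axis intact when `batch_size` is 1.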
