Forward pass for the gating mechanism. Args: x (torch.Tensor): Input tensor. Returns: Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
(self, x: torch.Tensor)
| 561 | self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None |
| 562 | |
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass for the gating mechanism.

    Scores every routed expert for each row of `x`, optionally restricts the
    candidates to the best `self.topk_groups` expert groups, picks the
    `self.topk` highest-scoring experts and returns their routing weights.

    The optional `self.bias` only influences *which* experts are selected;
    the returned weights are always gathered from the unbiased scores.

    Args:
        x (torch.Tensor): Input tensor. The code uses `x.size(0)` as the row
            count and gathers along dim 1, so it presumably expects a 2-D
            tensor of shape (tokens, dim) -- confirm against the caller.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Routing weights (cast to the dtype
        of `x`) and selected expert indices.
    """
    scores = linear(x, self.weight)
    if self.score_func == "softmax":
        scores = scores.softmax(dim=-1, dtype=torch.float32)
    else:
        scores = scores.sigmoid()
    # Keep a handle on the unbiased scores; the final weights are read from here.
    original_scores = scores
    if self.bias is not None:
        scores = scores + self.bias
    if self.n_groups > 1:
        # Group-limited routing: rank the expert groups and drop every expert
        # that belongs to a group outside the top `topk_groups`.
        scores = scores.view(x.size(0), self.n_groups, -1)
        if self.bias is None:
            # Without a bias a group is ranked by its single best expert.
            group_scores = scores.amax(dim=-1)
        else:
            # With a bias a group is ranked by the sum of its two best experts.
            group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
        indices = group_scores.topk(self.topk_groups, dim=-1)[1]
        # mask is True for the groups that were NOT selected.
        mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
        # Out-of-place on purpose: when `self.bias` is None, `scores` is a view
        # sharing storage with `original_scores`, so an in-place masked_fill_
        # would write -inf into the tensor the weights are gathered from (and
        # would mutate the activation output in place).
        scores = scores.masked_fill(mask.unsqueeze(-1), float("-inf")).flatten(1)
    indices = torch.topk(scores, self.topk, dim=-1)[1]
    weights = original_scores.gather(1, indices)
    if self.score_func == "sigmoid":
        # Sigmoid scores are not normalised, so rescale the selected ones to sum to 1.
        weights /= weights.sum(dim=-1, keepdim=True)
    weights *= self.route_scale
    return weights.type_as(x), indices
| 596 | |
| 597 | |
| 598 | class Expert(nn.Module): |