(self, sel:Tensor, x:Tensor)
def __init__(self, num_experts:int, in_features:int, out_features:int):
    """Create the stacked per-expert weight tensor, initialised to zeros.

    The weight is laid out as (num_experts, out_features, in_features),
    i.e. one (out_features, in_features) matrix per expert.
    """
    weight_shape = (num_experts, out_features, in_features)
    self.weight = Tensor.zeros(*weight_shape)
def __call__(self, sel:Tensor, x:Tensor) -> Tensor:
    """Apply the experts chosen by `sel` to `x`.

    Shapes (as stated by the original author):
    sel: (B, T, k), x: (B, T, 1, in) or (B, T, k, in) -> output: (B, T, k, out)
    """
    # Insert a singleton row axis so each input vector acts as a 1 x in matrix.
    x_rows = x.unsqueeze(-2)
    # Look up the weight matrix for each selected expert and swap its last two
    # axes (out, in) -> (in, out) so it can be right-multiplied.
    # NOTE(review): assumes indexing with `sel` gathers along the expert axis.
    picked_t = self.weight[sel].transpose(-1, -2)
    projected = (x_rows @ picked_t).contiguous()
    # Drop the singleton row axis added above.
    return projected.squeeze(-2)
| 21 | |
| 22 | def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor: |
| 23 | assert x.shape[-1] % 2 == 0 |
nothing calls this directly
no test coverage detected