| 80 | |
| 81 | # feedforward |
class GEGLU(nn.Module):
    """Gated-GELU projection layer.

    A single linear layer maps the input to ``2 * dim_out`` features. The
    result is split into two equal halves along the last axis: a "value" half
    and a "gate" half. The output is ``value * gelu(gate)``.

    Args:
        dim_in: size of the last dimension of the input.
        dim_out: size of the last dimension of the output.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One fused projection yields both the value and the gate halves.
        self.proj = nn.Linear(dim_in, 2 * dim_out)

    def forward(self, x):
        projected = self.proj(x)
        # First half of the last axis is the value, second half is the gate.
        value, gate = projected.chunk(2, dim=-1)
        gated = F.gelu(gate)
        return value * gated
| 90 | |
| 91 | |
| 92 | class FeedForward(nn.Module): |