(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0)
| 91 | |
class FeedForward(nn.Module):
    """Feed-forward block: input projection, dropout, output projection.

    The input projection is ``nn.Linear(dim, hidden)`` followed by
    ``nn.GELU()``, or ``GEGLU(dim, hidden)`` when ``glu`` is true. The hidden
    width is ``int(dim * mult)``.

    Args:
        dim: Input feature size passed to the input projection.
        dim_out: Output feature size, resolved with ``default(dim_out, dim)``
            (presumably ``dim`` when ``dim_out`` is None; confirm against
            ``default``).
        mult: Multiplier applied to ``dim`` to obtain the hidden width.
        glu: If true, use ``GEGLU`` for the input projection instead of
            ``Linear`` + ``GELU``.
        dropout: Probability for the ``nn.Dropout`` between the two
            projections.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)

        # Submodules are created in the same order as before (input projection,
        # dropout, output Linear). This keeps parameter initialisation drawing
        # from the RNG in the same sequence, and keeps the indices inside
        # ``self.net`` (0, 1, 2) stable for saved checkpoints.
        if glu:
            entry = GEGLU(dim, hidden)
        else:
            entry = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        layers = [entry, nn.Dropout(dropout), nn.Linear(hidden, out_features)]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the projection/dropout/projection stack."""
        return self.net(x)