(self, dim)
| 226 | """ |
| 227 | |
def __init__(self, dim):
    """Create the normalization, fused projection and output layers.

    Args:
        dim: Channel count of the input feature map. ``to_qkv`` widens it
            to ``3 * dim`` channels (presumably query/key/value, going by
            the layer name) and ``proj`` maps ``dim`` back to ``dim``.
    """
    super().__init__()
    self.dim = dim

    # Registration order (norm -> to_qkv -> proj) is deliberate: it fixes
    # the state_dict key order and the order in which the conv layers draw
    # their random initial weights.
    self.norm = RMS_norm(dim)
    fused_channels = 3 * dim
    self.to_qkv = nn.Conv2d(
        in_channels=dim, out_channels=fused_channels, kernel_size=1
    )
    self.proj = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=1)

    # Start the output projection with an all-zero weight. Only the weight
    # is zeroed here; proj.bias keeps nn.Conv2d's default initialization.
    nn.init.zeros_(self.proj.weight)
| 239 | |
| 240 | def forward(self, x): |
| 241 | identity = x |