(self, x: Tensor, emb: Optional[Tensor] = None)
| 255 | ) |
| 256 | |
| 257 | def forward(self, x: Tensor, emb: Optional[Tensor] = None): |
| 258 | emb = self.linear(self.silu(emb)) |
| 259 | shift_msa, scale_msa, gate_msa = chunk(emb, 3, dim=1) |
| 260 | x = self.norm(x) * (1 + unsqueeze(scale_msa, 1)) + unsqueeze( |
| 261 | shift_msa, 1) |
| 262 | return x, gate_msa |
| 263 | |
| 264 | |
| 265 | class AdaLayerNormContinuous(Module): |