| 330 | ) |
| 331 | |
def forward(self, hidden_states: Tensor, emb: Tensor):
    """Compute two modulated copies of the normalized input plus the remaining gates.

    Pipeline:
      1. ``emb`` -> ``self.silu`` -> cast to ``hidden_states.dtype`` -> ``self.linear``.
      2. Split that projection into 9 equal chunks along dim 1.
      3. Normalize ``hidden_states`` once with ``self.norm``.
      4. Apply ``x * (1 + scale) + shift`` to the normalized tensor twice:
         once with the ``*_msa`` pair and once with the ``*_msa2`` pair.
         ``scale`` and ``shift`` are unsqueezed at dim 1 so that they
         broadcast against it.

    NOTE(review): the unsqueeze at dim 1 suggests the chunks are
    (batch, dim) and ``hidden_states`` is (batch, seq, dim); confirm
    with callers.

    Returns:
        A 7-tuple ``(modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp,
        modulated2, gate_msa2)``. ``modulated`` uses the first shift/scale
        pair and ``modulated2`` uses the second. The remaining chunks are
        passed through unchanged for the caller to use.
    """
    # Cast the activated conditioning to the activation dtype before projecting it.
    conditioning = self.silu(emb).cast(hidden_states.dtype)
    projection = self.linear(conditioning)
    (
        shift_msa, scale_msa, gate_msa,
        shift_mlp, scale_mlp, gate_mlp,
        shift_msa2, scale_msa2, gate_msa2,
    ) = chunk(projection, 9, dim=1)

    def _modulate(x, shift, scale):
        # Add a broadcast axis at dim 1 so per-row shift/scale apply across it.
        return x * (1 + unsqueeze(scale, 1)) + unsqueeze(shift, 1)

    # A single normalization feeds both modulated outputs.
    normed = self.norm(hidden_states)
    modulated = _modulate(normed, shift_msa, scale_msa)
    modulated2 = _modulate(normed, shift_msa2, scale_msa2)
    return (
        modulated,
        gate_msa,
        shift_mlp,
        scale_mlp,
        gate_mlp,
        modulated2,
        gate_msa2,
    )