(self, x: Tensor, conditioning_embedding: Tensor)
| 296 | raise ValueError(f"unknown norm_type {norm_type}") |
| 297 | |
def forward(self, x: Tensor, conditioning_embedding: Tensor):
    """Modulate the normalized input with a scale/shift derived from the conditioning.

    The conditioning embedding is passed through ``self.silu``, cast to
    ``x.dtype``, and projected with ``self.linear``. The projection is split
    into two chunks along dim 1, where the first chunk is ``scale`` and the
    second is ``shift``. The return value is
    ``self.norm(x) * (1 + scale) + shift``, with both chunks unsqueezed at
    dim 1 before being applied.

    NOTE(review): presumably ``x`` is (batch, seq, dim) and
    ``conditioning_embedding`` is (batch, cond_dim), so the unsqueeze lets the
    per-sample scale/shift broadcast over the sequence axis. Confirm this.
    """
    # `conditioning_embedding` may have been upcast to float32 (needed for
    # hunyuanDiT), so bring the activation back to x's dtype before projecting.
    activated = self.silu(conditioning_embedding).cast(x.dtype)
    projected = self.linear(activated)
    scale, shift = chunk(projected, 2, dim=1)

    normalized = self.norm(x)
    gain = unsqueeze((1 + scale), 1)
    bias = unsqueeze(shift, 1)
    return normalized * gain + bias
| 304 | |
| 305 | |
| 306 | class SD35AdaLayerNormZeroX(Module): |