| 155 | dtype=dtype) |
| 156 | |
def forward(self,
            x: Tensor,
            timestep: Optional[Tensor] = None,
            temb: Optional[Tensor] = None):
    """Normalize ``x`` and modulate it with a conditioning scale and shift.

    The conditioning tensor is passed through ``self.silu`` and
    ``self.linear``, split into two halves, and applied as
    ``self.norm(x) * (1 + scale) + shift``.

    Args:
        x: Tensor to normalize and modulate.
        timestep: Optional input looked up through ``self.emb`` to obtain
            the conditioning tensor. Only used when ``self.emb`` is not
            ``None``; in that case it takes precedence over ``temb``.
        temb: Optional conditioning tensor, used directly when no
            embedding lookup is performed.

    Returns:
        The modulated tensor.

    Raises:
        AssertionError: If both ``timestep`` and ``temb`` are ``None``.
        ValueError: If only ``timestep`` is given but the layer has no
            embedding (``self.emb is None``), so no conditioning tensor
            can be derived.
    """
    assert timestep is not None or temb is not None, \
        "either `timestep` or `temb` must be provided"

    if self.emb is not None and timestep is not None:
        # An embedded timestep replaces any `temb` passed by the caller.
        temb = self.emb(timestep)

    if temb is None:
        # `timestep` was supplied but there is no embedding to convert it;
        # fail here rather than passing None into the activation below.
        raise ValueError(
            "`temb` must be provided when the layer has no timestep "
            "embedding (self.emb is None)")

    temb = self.linear(self.silu(temb))

    if self.chunk_dim == 1:
        # Split along dim 1, then insert a new axis at position 1 on each
        # half (presumably so they broadcast over a sequence axis of `x`
        # -- TODO confirm expected shapes).
        shift, scale = chunk(temb, 2, dim=1)
        shift = unsqueeze(shift, 1)
        scale = unsqueeze(scale, 1)
    else:
        # NOTE(review): the halves are unpacked as (scale, shift) here but
        # as (shift, scale) in the branch above. Kept exactly as written;
        # confirm against the reference implementation before changing.
        scale, shift = chunk(temb, 2, dim=0)

    return self.norm(x) * (1 + scale) + shift
| 173 | |
| 174 | |
| 175 | class AdaLayerNormZero(Module): |