forward(self, x: Tensor, *, cond: Tensor)
| 307 | |
@beartype
def forward(self, x: Tensor, *, cond: Tensor):
    """Normalize ``x`` and modulate it with a gain (and optional bias) predicted from ``cond``.

    Applies ``F.normalize`` to ``x`` along its feature axis. The feature axis is
    axis 1 when ``self.channel_first`` is set, otherwise the last axis. The result
    is multiplied by ``self.scale`` and by ``gamma = self.to_gamma(cond)``. Finally
    ``self.to_bias(cond)`` is added when ``self.to_bias`` exists; otherwise the
    offset is ``0.0``.

    ``gamma`` and ``bias`` are reshaped so that each sample's conditioning
    broadcasts over every axis of ``x`` other than batch and features. This works
    in both the channel-first and the channel-last layout.

    Args:
        x: Input tensor whose first axis is the batch axis. Extra axes (if any)
            follow the feature axis when channel-first and precede it when
            channel-last.
        cond: Conditioning tensor; must have shape ``(batch, self.dim_cond)``.

    Returns:
        The modulated tensor. Its shape matches ``x`` provided ``to_gamma`` and
        ``to_bias`` output x's feature size. That is presumably the case, but
        TODO confirm in ``__init__``.

    Raises:
        AssertionError: if ``cond.shape != (batch, self.dim_cond)``.
    """
    batch = x.shape[0]
    # NOTE: `assert` is skipped under `python -O`; kept to preserve the
    # AssertionError contract callers may rely on.
    assert cond.shape == (batch, self.dim_cond)

    gamma = self.to_gamma(cond)

    bias = 0.0
    if exists(self.to_bias):
        bias = self.to_bias(cond)

    # Number of axes in x besides the batch axis and the feature axis.
    num_extra_dims = x.ndim - 2

    if self.channel_first:
        # Features sit right after batch. `append_dims` (defined elsewhere in
        # this module) presumably adds trailing singleton axes so that
        # (batch, features) broadcasts over the trailing axes of x.
        # TODO confirm.
        gamma = append_dims(gamma, num_extra_dims)

        if exists(self.to_bias):
            bias = append_dims(bias, num_extra_dims)

    elif num_extra_dims > 0:
        # Channel-last: features are the LAST axis, so the singleton axes must
        # go between batch and features. Left as (batch, features), broadcasting
        # would pair the batch axis with x's second-to-last axis, which either
        # errors or silently mis-broadcasts (e.g. a (batch, 1, features) input
        # would produce a (batch, batch, features) output).
        singletons = (1,) * num_extra_dims
        gamma = gamma.reshape(batch, *singletons, gamma.shape[-1])

        if exists(self.to_bias):
            bias = bias.reshape(batch, *singletons, bias.shape[-1])

    feature_dim = 1 if self.channel_first else -1
    return F.normalize(x, dim=feature_dim) * self.scale * gamma + bias
| 326 | |
| 327 | |
| 328 | # attention |
nothing calls this directly
no test coverage detected