(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False)
| 452 | class FeedForward(Module): |
| 453 | @beartype |
| 454 | def __init__(self, dim, *, dim_cond: Optional[int] = None, mult=4, images=False):  # Gated (GEGLU) pointwise-conv feedforward. dim: input/output channel count; dim_cond: if given, norm becomes AdaptiveRMSNorm(dim_cond=dim_cond) (presumably the conditioning vector size — confirm); mult: hidden-width multiplier; images: Conv2d if True else Conv3d |
| 455 | super().__init__() |
| 456 | conv_klass = nn.Conv2d if images else nn.Conv3d  # 2D convs when images=True, otherwise 3D (presumably video/volumes — confirm) |
| 457 | |
| 458 | rmsnorm_klass = RMSNorm if not exists(dim_cond) else partial(AdaptiveRMSNorm, dim_cond=dim_cond)  # plain RMSNorm unless a dim_cond is supplied, then the adaptive variant with dim_cond pre-bound |
| 459 | |
| 460 | maybe_adaptive_norm_klass = partial(rmsnorm_klass, channel_first=True, images=images)  # channel_first=True to match the channel-first layout nn.ConvNd expects |
| 461 | |
| 462 | dim_inner = int(dim * mult * 2 / 3)  # 2/3 factor offsets the doubled up-projection below, so total weights (~2*mult*dim^2) match an ungated dim->mult*dim->dim MLP |
| 463 | |
| 464 | self.norm = maybe_adaptive_norm_klass(dim)  # normalization over `dim` channels; how forward() applies it is not visible here |
| 465 | |
| 466 | self.net = Sequential(conv_klass(dim, dim_inner * 2, 1), GEGLU(), conv_klass(dim_inner, dim, 1))  # 1x1 conv dim -> 2*dim_inner, GEGLU (must halve channels to dim_inner for the next conv to fit), 1x1 conv dim_inner -> dim |
| 467 | |
| 468 | @beartype |
| 469 | def forward(self, x: Tensor, *, cond: Optional[Tensor] = None): |
nothing calls this directly
no test coverage detected