class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of time embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
            The type of normalization to apply to the time embeddings. Accepted values are `default` and
            `spatial`; `spatial` builds the block from `ResnetBlockCondNorm2D` layers.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.
        attn_groups (`int | None`, *optional*, defaults to None):
            The number of groups for the attention blocks. If `None`, falls back to `resnet_groups` when
            `resnet_time_scale_shift` is `default`.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.

    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
        height, width)`.

    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: int | None = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        # Fall back to a channel-derived group count if None is passed explicitly.
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        # Attention reuses the resnet group count only with the "default" scheme.
        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        if resnet_time_scale_shift == "spatial":
            resnets = [
                ResnetBlockCondNorm2D(
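The group-count fallbacks at the top of `__init__` can be read in isolation. The sketch below copies the two conditionals from the excerpt into a small pure function so the resolved values are easy to inspect. The helper name `resolve_groups` is hypothetical and not part of the original module.

```python
from __future__ import annotations


def resolve_groups(
    in_channels: int,
    resnet_groups: int | None = 32,
    attn_groups: int | None = None,
    resnet_time_scale_shift: str = "default",
) -> tuple[int, int | None]:
    """Hypothetical helper mirroring the group defaults in `UNetMidBlock2D.__init__`."""
    # Fall back to a channel-derived group count when none is given.
    if resnet_groups is None:
        resnet_groups = min(in_channels // 4, 32)

    # Attention reuses the resnet group count only for the "default" scheme;
    # with any other scheme an unset attn_groups stays None.
    if attn_groups is None:
        attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

    return resnet_groups, attn_groups
```

For example, `resolve_groups(64, resnet_groups=None)` resolves both counts to 16, while `resolve_groups(64, resnet_time_scale_shift="spatial")` leaves the attention group count as `None`.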
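The docstring says the number of attention heads follows from `attention_head_dim` and `in_channels`, but the excerpt does not show the computation. A plausible reading is an integer division of the channel count by the per-head dimension. The sketch below encodes that assumption only; the function name `num_attention_heads` is hypothetical.

```python
def num_attention_heads(in_channels: int, attention_head_dim: int) -> int:
    """Hypothetical helper: assumes heads = in_channels // attention_head_dim."""
    if attention_head_dim <= 0:
        raise ValueError("attention_head_dim must be positive")
    # Assumption: channels are split evenly across heads of size attention_head_dim.
    return in_channels // attention_head_dim
```

Under this assumption, the default `attention_head_dim=1` would give one head per channel, so callers would typically pass a larger value such as 64.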