Class UNetMidBlock2D

src/diffusers/models/unets/unet_2d_blocks.py:589–748 in github.com/huggingface/diffusers

A 2D UNet mid-block with multiple residual blocks and optional attention blocks.
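
Judging from the upstream diffusers source (the loop that builds the remaining blocks lies past the end of the excerpt below), the block runs one residual block first and then num_layers repetitions of an optional attention block followed by another residual block. So num_layers=1 yields two residual blocks, despite the docstring calling num_layers "the number of residual blocks". The plain-Python sketch below lists that execution order; mid_block_layout is a hypothetical helper, not a diffusers API.

```python
from __future__ import annotations


def mid_block_layout(num_layers: int = 1, add_attention: bool = True) -> list[str]:
    """Hypothetical helper (not a diffusers API): sub-block execution order of the mid-block."""
    # The first residual block always exists ("there is always at least one resnet").
    layout = ["resnet"]
    for _ in range(num_layers):
        # Each layer adds an attention block (if enabled) followed by another residual block.
        if add_attention:
            layout.append("attention")
        layout.append("resnet")
    return layout


print(mid_block_layout(1))                       # ['resnet', 'attention', 'resnet']
print(mid_block_layout(2, add_attention=False))  # ['resnet', 'resnet', 'resnet']
```

Every sub-block keeps the (batch_size, in_channels, height, width) shape, which is why the Returns section can state the output shape in terms of in_channels.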

class UNetMidBlock2D(nn.Module):
    """
    A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.

    Args:
        in_channels (`int`): The number of input channels.
        temb_channels (`int`): The number of temporal embedding channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
            How the time embedding is applied in the resnet blocks, either `"default"` or `"spatial"`. With
            `"spatial"`, the blocks are built from `ResnetBlockCondNorm2D`, whose normalization is conditioned on the
            embedding.
        resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks. If `None`,
            `min(in_channels // 4, 32)` is used.
        attn_groups (`int | None`, *optional*, defaults to None):
            The number of groups for the group normalization in the attention blocks. If `None`, it falls back to
            `resnet_groups` when `resnet_time_scale_shift` is `"default"` and stays `None` otherwise.
        resnet_pre_norm (`bool`, *optional*, defaults to `True`):
            Whether to use pre-normalization for the resnet blocks.
        add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
        attention_head_dim (`int`, *optional*, defaults to 1):
            Dimension of a single attention head. The number of attention heads is determined based on this value and
            the number of input channels.
        output_scale_factor (`float`, *optional*, defaults to 1.0):
            The factor by which the resnet and attention blocks divide their outputs.

    Returns:
        `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
        height, width)`.

    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        attn_groups: int | None = None,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
    ):
        super().__init__()
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        self.add_attention = add_attention

        if attn_groups is None:
            attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None

        # there is always at least one resnet
        if resnet_time_scale_shift == "spatial":
            resnets = [
                ResnetBlockCondNorm2D(
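
The group and head-count defaults resolved in __init__ can be checked without instantiating the module. The helper below, resolve_mid_block_config, is hypothetical and not part of diffusers. It mirrors the resnet_groups and attn_groups fallbacks in the listing above, and it assumes, from upstream code beyond this excerpt, that the attention blocks use in_channels // attention_head_dim heads. The divisibility check reflects the torch.nn.GroupNorm requirement that the channel count be divisible by the number of groups.

```python
from __future__ import annotations


def resolve_mid_block_config(
    in_channels: int,
    resnet_groups: int | None = 32,
    attn_groups: int | None = None,
    resnet_time_scale_shift: str = "default",
    attention_head_dim: int = 1,
) -> dict:
    """Hypothetical helper (not a diffusers API) that mirrors UNetMidBlock2D's defaults."""
    # Same fallback as __init__: derive the group count from the channel count.
    if resnet_groups is None:
        resnet_groups = min(in_channels // 4, 32)
    # torch.nn.GroupNorm requires num_channels to be divisible by num_groups.
    if resnet_groups <= 0 or in_channels % resnet_groups != 0:
        raise ValueError(f"in_channels={in_channels} is not divisible by resnet_groups={resnet_groups}")
    # Attention reuses the resnet group count only in "default" mode; "spatial" leaves it None.
    if attn_groups is None:
        attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
    # Assumption based on upstream code outside this excerpt: heads = in_channels // attention_head_dim.
    heads = in_channels // attention_head_dim
    return {"resnet_groups": resnet_groups, "attn_groups": attn_groups, "heads": heads}


print(resolve_mid_block_config(512))                      # {'resnet_groups': 32, 'attn_groups': 32, 'heads': 512}
print(resolve_mid_block_config(64, resnet_groups=None))   # {'resnet_groups': 16, 'attn_groups': 16, 'heads': 64}
```

With the default attention_head_dim=1, this assumption gives one head per input channel; callers that want fewer, wider heads pass a larger attention_head_dim.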

Callers (8)

__init__ (method) · 0.90
__init__ (method) · 0.85
get_mid_block (function) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85
