class UNet2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D UNet model that takes a noisy sample and a timestep and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of the input/output sample. Dimensions must be a multiple of
            `2 ** (len(block_out_channels) - 1)`.
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        down_block_types (`tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
            Tuple of downsample block types.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
            Block type for the middle of the UNet. It can be either `UNetMidBlock2D` or `None`.
        up_block_types (`tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
            Tuple of upsample block types.
        block_out_channels (`tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
        mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
        downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
        downsample_type (`str`, *optional*, defaults to `"conv"`):
            The downsample type for downsampling layers. Choose between `"conv"` and `"resnet"`.
        upsample_type (`str`, *optional*, defaults to `"conv"`):
            The upsample type for upsampling layers. Choose between `"conv"` and `"resnet"`.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
        norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
        attn_norm_num_groups (`int`, *optional*, defaults to `None`):
            If set to an integer, a group norm layer will be created in the mid block's [`Attention`] layer with the
            given number of groups. If left as `None`, the group norm layer will only be created if
            `resnet_time_scale_shift` is set to `default`, and if created will have `norm_num_groups` groups.
        norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
            Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from
            `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use, which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, or `"identity"`.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class
            conditioning with `class_embed_type` equal to `None`.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]

    @register_to_config
    def __init__(
        self,
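
Note on the `sample_size` constraint: the docstring requires each spatial dimension to be a multiple of `2 ** (len(block_out_channels) - 1)`, presumably because every down block except the last halves the resolution. The sketch below shows one way to check a candidate size before building the model. The helper names `required_multiple` and `is_valid_sample_size` are hypothetical and not part of the library; only the formula and the default `block_out_channels` come from the docstring.

```python
from typing import Sequence, Tuple, Union


def required_multiple(block_out_channels: Sequence[int]) -> int:
    """Return 2 ** (len(block_out_channels) - 1), the factor named in the docstring."""
    return 2 ** (len(block_out_channels) - 1)


def is_valid_sample_size(
    sample_size: Union[int, Tuple[int, int]],
    block_out_channels: Sequence[int] = (224, 448, 672, 896),  # docstring default
) -> bool:
    """Check that every spatial dimension is a multiple of the required factor.

    Hypothetical helper for illustration; the library may validate differently.
    """
    # An int is treated as a square sample: (height, width) = (n, n).
    if isinstance(sample_size, int):
        dims = (sample_size, sample_size)
    else:
        dims = tuple(sample_size)
    factor = required_multiple(block_out_channels)
    return all(dim % factor == 0 for dim in dims)
```

With the default four-entry `block_out_channels` the factor is `2 ** 3`, so a sample size of 64 passes the check while `(64, 36)` does not.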