| 134 | |
| 135 | |
| 136 | class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin): |
| 137 | _supports_gradient_checkpointing = True |
| 138 | |
| 139 | @register_to_config |
| 140 | def __init__( |
| 141 | self, |
| 142 | in_channels: int = 16, |
| 143 | out_channels: int = 16, |
| 144 | timestep_ratio_embedding_dim: int = 64, |
| 145 | patch_size: int = 1, |
| 146 | conditioning_dim: int = 2048, |
| 147 | block_out_channels: tuple[int, ...] = (2048, 2048), |
| 148 | num_attention_heads: tuple[int, ...] = (32, 32), |
| 149 | down_num_layers_per_block: tuple[int, ...] = (8, 24), |
| 150 | up_num_layers_per_block: tuple[int, ...] = (24, 8), |
| 151 | down_blocks_repeat_mappers: tuple[int] | None = ( |
| 152 | 1, |
| 153 | 1, |
| 154 | ), |
| 155 | up_blocks_repeat_mappers: tuple[int] | None = (1, 1), |
| 156 | block_types_per_layer: tuple[tuple[str]] = ( |
| 157 | ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), |
| 158 | ("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"), |
| 159 | ), |
| 160 | clip_text_in_channels: int | None = None, |
| 161 | clip_text_pooled_in_channels=1280, |
| 162 | clip_image_in_channels: int | None = None, |
| 163 | clip_seq=4, |
| 164 | effnet_in_channels: int | None = None, |
| 165 | pixel_mapper_in_channels: int | None = None, |
| 166 | kernel_size=3, |
| 167 | dropout: float | tuple[float] = (0.1, 0.1), |
| 168 | self_attn: bool | tuple[bool] = True, |
| 169 | timestep_conditioning_type: tuple[str, ...] = ("sca", "crp"), |
| 170 | switch_level: tuple[bool] | None = None, |
| 171 | ): |
| 172 | """ |
| 173 | |
| 174 | Parameters: |
| 175 | in_channels (`int`, defaults to 16): |
| 176 | Number of channels in the input sample. |
| 177 | out_channels (`int`, defaults to 16): |
| 178 | Number of channels in the output sample. |
| 179 | timestep_ratio_embedding_dim (`int`, defaults to 64): |
| 180 | Dimension of the projected time embedding. |
| 181 | patch_size (`int`, defaults to 1): |
| 182 | Patch size to use for pixel unshuffling layer |
| 183 | conditioning_dim (`int`, defaults to 2048): |
| 184 | Dimension of the image and text conditional embedding. |
| 185 | block_out_channels (tuple[int], defaults to (2048, 2048)): |
| 186 | tuple of output channels for each block. |
| 187 | num_attention_heads (tuple[int], defaults to (32, 32)): |
| 188 | Number of attention heads in each attention block. Set to -1 if block types in a layer do not have |
| 189 | attention. |
| 190 | down_num_layers_per_block (tuple[int], defaults to (8, 24)): |
| 191 | Number of layers in each down block. |
| 192 | up_num_layers_per_block (tuple[int], defaults to (24, 8)): |
| 193 | Number of layers in each up block. |
no outgoing calls
searching dependent graphs…