(
self,
in_channels: int = 3,
out_channels: int = 3,
up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
block_out_channels: tuple[int, ...] = (64,),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
mid_block_add_attention=True,
)
| 201 | """ |
| 202 | |
| 203 | def __init__( |
| 204 | self, |
| 205 | in_channels: int = 3, |
| 206 | out_channels: int = 3, |
| 207 | up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",), |
| 208 | block_out_channels: tuple[int, ...] = (64,), |
| 209 | layers_per_block: int = 2, |
| 210 | norm_num_groups: int = 32, |
| 211 | act_fn: str = "silu", |
| 212 | norm_type: str = "group", # group, spatial |
| 213 | mid_block_add_attention=True, |
| 214 | ): |
| 215 | super().__init__() |
| 216 | self.layers_per_block = layers_per_block |
| 217 | |
| 218 | self.conv_in = nn.Conv2d( |
| 219 | in_channels, |
| 220 | block_out_channels[-1], |
| 221 | kernel_size=3, |
| 222 | stride=1, |
| 223 | padding=1, |
| 224 | ) |
| 225 | |
| 226 | self.up_blocks = nn.ModuleList([]) |
| 227 | |
| 228 | temb_channels = in_channels if norm_type == "spatial" else None |
| 229 | |
| 230 | # mid |
| 231 | self.mid_block = UNetMidBlock2D( |
| 232 | in_channels=block_out_channels[-1], |
| 233 | resnet_eps=1e-6, |
| 234 | resnet_act_fn=act_fn, |
| 235 | output_scale_factor=1, |
| 236 | resnet_time_scale_shift="default" if norm_type == "group" else norm_type, |
| 237 | attention_head_dim=block_out_channels[-1], |
| 238 | resnet_groups=norm_num_groups, |
| 239 | temb_channels=temb_channels, |
| 240 | add_attention=mid_block_add_attention, |
| 241 | ) |
| 242 | |
| 243 | # up |
| 244 | reversed_block_out_channels = list(reversed(block_out_channels)) |
| 245 | output_channel = reversed_block_out_channels[0] |
| 246 | for i, up_block_type in enumerate(up_block_types): |
| 247 | prev_output_channel = output_channel |
| 248 | output_channel = reversed_block_out_channels[i] |
| 249 | |
| 250 | is_final_block = i == len(block_out_channels) - 1 |
| 251 | |
| 252 | up_block = get_up_block( |
| 253 | up_block_type, |
| 254 | num_layers=self.layers_per_block + 1, |
| 255 | in_channels=prev_output_channel, |
| 256 | out_channels=output_channel, |
| 257 | prev_output_channel=prev_output_channel, |
| 258 | add_upsample=not is_final_block, |
| 259 | resnet_eps=1e-6, |
| 260 | resnet_act_fn=act_fn, |
Nothing calls this directly.
No test coverage detected.