(
mid_block_type: str,
temb_channels: int,
in_channels: int,
resnet_eps: float,
resnet_act_fn: str,
resnet_groups: int,
output_scale_factor: float = 1.0,
transformer_layers_per_block: int = 1,
num_attention_heads: int | None = None,
cross_attention_dim: int | None = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
mid_block_only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
cross_attention_norm: str | None = None,
attention_head_dim: int | None = 1,
dropout: float = 0.0,
)
| 250 | |
| 251 | |
| 252 | def get_mid_block( |
| 253 | mid_block_type: str, |
| 254 | temb_channels: int, |
| 255 | in_channels: int, |
| 256 | resnet_eps: float, |
| 257 | resnet_act_fn: str, |
| 258 | resnet_groups: int, |
| 259 | output_scale_factor: float = 1.0, |
| 260 | transformer_layers_per_block: int = 1, |
| 261 | num_attention_heads: int | None = None, |
| 262 | cross_attention_dim: int | None = None, |
| 263 | dual_cross_attention: bool = False, |
| 264 | use_linear_projection: bool = False, |
| 265 | mid_block_only_cross_attention: bool = False, |
| 266 | upcast_attention: bool = False, |
| 267 | resnet_time_scale_shift: str = "default", |
| 268 | attention_type: str = "default", |
| 269 | resnet_skip_time_act: bool = False, |
| 270 | cross_attention_norm: str | None = None, |
| 271 | attention_head_dim: int | None = 1, |
| 272 | dropout: float = 0.0, |
| 273 | ): |
| 274 | if mid_block_type == "UNetMidBlock2DCrossAttn": |
| 275 | return UNetMidBlock2DCrossAttn( |
| 276 | transformer_layers_per_block=transformer_layers_per_block, |
| 277 | in_channels=in_channels, |
| 278 | temb_channels=temb_channels, |
| 279 | dropout=dropout, |
| 280 | resnet_eps=resnet_eps, |
| 281 | resnet_act_fn=resnet_act_fn, |
| 282 | output_scale_factor=output_scale_factor, |
| 283 | resnet_time_scale_shift=resnet_time_scale_shift, |
| 284 | cross_attention_dim=cross_attention_dim, |
| 285 | num_attention_heads=num_attention_heads, |
| 286 | resnet_groups=resnet_groups, |
| 287 | dual_cross_attention=dual_cross_attention, |
| 288 | use_linear_projection=use_linear_projection, |
| 289 | upcast_attention=upcast_attention, |
| 290 | attention_type=attention_type, |
| 291 | ) |
| 292 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": |
| 293 | return UNetMidBlock2DSimpleCrossAttn( |
| 294 | in_channels=in_channels, |
| 295 | temb_channels=temb_channels, |
| 296 | dropout=dropout, |
| 297 | resnet_eps=resnet_eps, |
| 298 | resnet_act_fn=resnet_act_fn, |
| 299 | output_scale_factor=output_scale_factor, |
| 300 | cross_attention_dim=cross_attention_dim, |
| 301 | attention_head_dim=attention_head_dim, |
| 302 | resnet_groups=resnet_groups, |
| 303 | resnet_time_scale_shift=resnet_time_scale_shift, |
| 304 | skip_time_act=resnet_skip_time_act, |
| 305 | only_cross_attention=mid_block_only_cross_attention, |
| 306 | cross_attention_norm=cross_attention_norm, |
| 307 | ) |
| 308 | elif mid_block_type == "UNetMidBlock2D": |
| 309 | return UNetMidBlock2D( |
nothing calls this directly
no test coverage detected
searching dependent graphs…