MCP · Copy · Index your code
hub / github.com/huggingface/diffusers / get_mid_block

Function get_mid_block

src/diffusers/models/unets/unet_2d_blocks.py:252–324  ·  view source on GitHub ↗
(
    mid_block_type: str,
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: int | None = None,
    cross_attention_dim: int | None = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    cross_attention_norm: str | None = None,
    attention_head_dim: int | None = 1,
    dropout: float = 0.0,
)

Source from the content-addressed store, hash-verified

250
251
252def get_mid_block(
253 mid_block_type: str,
254 temb_channels: int,
255 in_channels: int,
256 resnet_eps: float,
257 resnet_act_fn: str,
258 resnet_groups: int,
259 output_scale_factor: float = 1.0,
260 transformer_layers_per_block: int = 1,
261 num_attention_heads: int | None = None,
262 cross_attention_dim: int | None = None,
263 dual_cross_attention: bool = False,
264 use_linear_projection: bool = False,
265 mid_block_only_cross_attention: bool = False,
266 upcast_attention: bool = False,
267 resnet_time_scale_shift: str = "default",
268 attention_type: str = "default",
269 resnet_skip_time_act: bool = False,
270 cross_attention_norm: str | None = None,
271 attention_head_dim: int | None = 1,
272 dropout: float = 0.0,
273):
274 if mid_block_type == "UNetMidBlock2DCrossAttn":
275 return UNetMidBlock2DCrossAttn(
276 transformer_layers_per_block=transformer_layers_per_block,
277 in_channels=in_channels,
278 temb_channels=temb_channels,
279 dropout=dropout,
280 resnet_eps=resnet_eps,
281 resnet_act_fn=resnet_act_fn,
282 output_scale_factor=output_scale_factor,
283 resnet_time_scale_shift=resnet_time_scale_shift,
284 cross_attention_dim=cross_attention_dim,
285 num_attention_heads=num_attention_heads,
286 resnet_groups=resnet_groups,
287 dual_cross_attention=dual_cross_attention,
288 use_linear_projection=use_linear_projection,
289 upcast_attention=upcast_attention,
290 attention_type=attention_type,
291 )
292 elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
293 return UNetMidBlock2DSimpleCrossAttn(
294 in_channels=in_channels,
295 temb_channels=temb_channels,
296 dropout=dropout,
297 resnet_eps=resnet_eps,
298 resnet_act_fn=resnet_act_fn,
299 output_scale_factor=output_scale_factor,
300 cross_attention_dim=cross_attention_dim,
301 attention_head_dim=attention_head_dim,
302 resnet_groups=resnet_groups,
303 resnet_time_scale_shift=resnet_time_scale_shift,
304 skip_time_act=resnet_skip_time_act,
305 only_cross_attention=mid_block_only_cross_attention,
306 cross_attention_norm=cross_attention_norm,
307 )
308 elif mid_block_type == "UNetMidBlock2D":
309 return UNetMidBlock2D(

Callers

nothing calls this directly

Calls 3

UNetMidBlock2D (class) · 0.85

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…