hub / github.com/openai/guided-diffusion / UNetModel

Class UNetModel

guided_diffusion/unet.py:396–663

The full UNet model with attention and timestep embedding.

Source from the content-addressed store, hash-verified

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads
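
The docstring describes `attention_resolutions` as downsample rates, `channel_mult` as a per-level multiplier on `model_channels`, and `num_head_channels` as a fixed channel width per head. The sketch below works through how those three parameters could combine at each level. It is an illustration, not the code from `unet.py`: the helpers `resolve_heads` and `plan_levels` and the `LevelPlan` type are hypothetical, and it assumes the downsample rate starts at 1 and doubles at each level, which the excerpt above does not show.

```python
from typing import Iterable, List, NamedTuple, Sequence


class LevelPlan(NamedTuple):
    """Hypothetical per-level summary, used only for this illustration."""
    level: int
    downsample_rate: int  # assumed: 1 at full resolution, doubling per level
    channels: int         # model_channels * channel_mult[level]
    use_attention: bool
    heads: int            # 0 when the level has no attention


def resolve_heads(channels: int, num_heads: int = 1, num_head_channels: int = -1) -> int:
    """Pick a head count: a fixed width per head overrides num_heads."""
    if num_head_channels == -1:
        return num_heads
    if channels % num_head_channels != 0:
        raise ValueError(
            f"{channels} channels are not divisible by head width {num_head_channels}"
        )
    return channels // num_head_channels


def plan_levels(
    model_channels: int,
    attention_resolutions: Iterable[int],
    channel_mult: Sequence[int] = (1, 2, 4, 8),
    num_heads: int = 1,
    num_head_channels: int = -1,
) -> List[LevelPlan]:
    """List channel width and attention settings for each UNet level."""
    attention_rates = set(attention_resolutions)  # a set, list, or tuple is accepted
    plans = []
    for level, mult in enumerate(channel_mult):
        rate = 2 ** level  # assumption: each level halves the spatial size
        channels = model_channels * mult
        use_attention = rate in attention_rates
        heads = resolve_heads(channels, num_heads, num_head_channels) if use_attention else 0
        plans.append(LevelPlan(level, rate, channels, use_attention, heads))
    return plans


if __name__ == "__main__":
    for plan in plan_levels(128, attention_resolutions=(4, 8), num_head_channels=64):
        print(plan)
```

With `model_channels=128`, `attention_resolutions=(4, 8)`, and `num_head_channels=64`, only the two deepest levels in this sketch get attention, and their head counts grow with channel width because the width per head is fixed.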

Callers 1

create_model · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected