github.com/openai/improved-diffusion / UNetModel

Class UNetModel

improved_diffusion/unet.py:278–523



```python
import torch.nn as nn  # needed for nn.Module


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
    ):
        super().__init__()

        # -1 is a sentinel: the upsampling path reuses num_heads.
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_heads_upsample = num_heads_upsample
        # ... the rest of the class (through line 523) is not shown in this excerpt
```
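The constructor body that builds the layers (lines 336–523) is not shown in this excerpt, so the sketch below only illustrates how the documented arguments relate to one another. It assumes, from the docstring rather than the unshown code, that the downsample rate starts at 1 and doubles after every level except the last, and that each level's width is `mult * model_channels`. `LevelSpec`, `describe_encoder_levels`, and `resolve_upsample_heads` are hypothetical names, not part of improved-diffusion.

```python
from typing import Iterable, List, NamedTuple


class LevelSpec(NamedTuple):
    """Summary of one encoder level (hypothetical helper type)."""
    level: int
    downsample_rate: int
    channels: int
    res_blocks: int
    uses_attention: bool


def describe_encoder_levels(
    model_channels: int,
    channel_mult: Iterable[int],
    num_res_blocks: int,
    attention_resolutions: Iterable[int],
) -> List[LevelSpec]:
    """Map the documented constructor arguments onto encoder levels.

    Assumption: the downsample rate starts at 1 and doubles after every
    level except the last. This follows the docstring, not the unshown code.
    """
    attention_rates = set(attention_resolutions)  # set, list, or tuple accepted
    mults = list(channel_mult)
    specs = []
    rate = 1
    for level, mult in enumerate(mults):
        specs.append(
            LevelSpec(
                level=level,
                downsample_rate=rate,
                channels=mult * model_channels,  # channel multiplier per level
                res_blocks=num_res_blocks,
                # "if this contains 4, then at 4x downsampling, attention will be used"
                uses_attention=rate in attention_rates,
            )
        )
        if level != len(mults) - 1:  # assumed: no downsampling after the last level
            rate *= 2
    return specs


def resolve_upsample_heads(num_heads: int, num_heads_upsample: int = -1) -> int:
    # Same rule as the constructor: -1 means "reuse num_heads".
    return num_heads if num_heads_upsample == -1 else num_heads_upsample
```

For example, `describe_encoder_levels(128, (1, 2, 3, 4), 3, (4, 8))` yields downsample rates 1, 2, 4, 8 with widths 128, 256, 384, 512, and flags attention only at the last two levels, which is how the docstring reads `attention_resolutions`: as downsample rates, not pixel sizes.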

Callers 1

create_model (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected
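The docstring defines `attention_resolutions` as downsample rates rather than feature-map sizes: with a 64x64 input, attention at 16x16 and 8x8 means passing rates 4 and 8. A caller that configures the model by resolution (for instance the listed `create_model`, whose body is not shown on this page, so this is not a claim about what it does) would need a conversion along these lines. This is a minimal sketch; `resolutions_to_downsample_rates` is a hypothetical helper, not an improved-diffusion API, and the divisibility check is an added assumption.

```python
from typing import Iterable, Tuple


def resolutions_to_downsample_rates(
    image_size: int, resolutions: Iterable[int]
) -> Tuple[int, ...]:
    """Convert feature-map resolutions (in pixels) into downsample rates.

    Hypothetical helper: UNetModel's docstring expects rates such as 4
    ("attention at 4x downsampling"), so a 16x16 map in a 64x64 model is rate 4.
    """
    rates = []
    for res in resolutions:
        # Assumed validation: reject resolutions that do not divide the input size.
        if res <= 0 or image_size % res != 0:
            raise ValueError(f"{res} does not evenly divide image size {image_size}")
        rates.append(image_size // res)
    return tuple(rates)
```

The resulting tuple can then be passed as `attention_resolutions`, which the docstring says may be a set, list, or tuple.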