MCPcopy
hub / github.com/openai/improved-diffusion / ResBlock

Class ResBlock

improved_diffusion/unet.py:107–197  ·  view source on GitHub ↗

A residual block that can optionally change the number of channels. Parameters: `channels` — the number of input channels; `emb_channels` — the number of timestep embedding channels; `dropout` — the rate of dropout; `out_channels` — if specified, the number of output channels.

Source from the content-addressed store, hash-verified

105
106
107class ResBlock(TimestepBlock):
108 """
109 A residual block that can optionally change the number of channels.
110
111 :param channels: the number of input channels.
112 :param emb_channels: the number of timestep embedding channels.
113 :param dropout: the rate of dropout.
114 :param out_channels: if specified, the number of out channels.
115 :param use_conv: if True and out_channels is specified, use a spatial
116 convolution instead of a smaller 1x1 convolution to change the
117 channels in the skip connection.
118 :param dims: determines if the signal is 1D, 2D, or 3D.
119 :param use_checkpoint: if True, use gradient checkpointing on this module.
120 """
121
122 def __init__(
123 self,
124 channels,
125 emb_channels,
126 dropout,
127 out_channels=None,
128 use_conv=False,
129 use_scale_shift_norm=False,
130 dims=2,
131 use_checkpoint=False,
132 ):
133 super().__init__()
134 self.channels = channels
135 self.emb_channels = emb_channels
136 self.dropout = dropout
137 self.out_channels = out_channels or channels
138 self.use_conv = use_conv
139 self.use_checkpoint = use_checkpoint
140 self.use_scale_shift_norm = use_scale_shift_norm
141
142 self.in_layers = nn.Sequential(
143 normalization(channels),
144 SiLU(),
145 conv_nd(dims, channels, self.out_channels, 3, padding=1),
146 )
147 self.emb_layers = nn.Sequential(
148 SiLU(),
149 linear(
150 emb_channels,
151 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
152 ),
153 )
154 self.out_layers = nn.Sequential(
155 normalization(self.out_channels),
156 SiLU(),
157 nn.Dropout(p=dropout),
158 zero_module(
159 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
160 ),
161 )
162
163 if self.out_channels == channels:
164 self.skip_connection = nn.Identity()

Callers 1

__init__ · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected