Repository: github.com/openai/guided-diffusion

Class ResBlock

guided_diffusion/unet.py:143–256

A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
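The skip-path choice that the docstring describes for `out_channels` and `use_conv` can be sketched as plain decision logic. This is a hedged sketch, not the upstream code: `skip_projection` is a hypothetical helper, and the identity case for matching channel counts is an assumption, because the lines that build the skip connection fall outside the excerpt shown here. The default mirrors the constructor's `self.out_channels = out_channels or channels`.

```python
from typing import Optional, Tuple


def skip_projection(
    channels: int,
    out_channels: Optional[int] = None,
    use_conv: bool = False,
) -> Tuple[int, str]:
    """Hypothetical helper: pick the skip-path projection for a ResBlock.

    Returns (resolved_out_channels, kind), where kind is one of
    "identity", "spatial_conv", or "conv1x1".
    """
    # Same fallback as the constructor: no out_channels means "keep channels".
    out = out_channels or channels
    if out == channels:
        # Assumption: shapes already match, so the input passes through unchanged.
        return out, "identity"
    if use_conv:
        # Docstring: use a spatial convolution to change the channel count.
        return out, "spatial_conv"
    # Docstring: otherwise a smaller 1x1 convolution is used.
    return out, "conv1x1"


print(skip_projection(128))                      # (128, 'identity')
print(skip_projection(128, 256))                 # (256, 'conv1x1')
print(skip_projection(128, 256, use_conv=True))  # (256, 'spatial_conv')
```

In this sketch `use_conv` only matters when the channel count actually changes, which is consistent with the docstring's wording "if True and out_channels is specified".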


```python
# Excerpt of guided_diffusion/unet.py. `nn` is torch.nn; `normalization` and
# `conv_nd` are helpers from guided_diffusion/nn.py, and TimestepBlock,
# Upsample, and Downsample are defined earlier in unet.py.


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # Pre-activation input path: norm -> SiLU -> 3x3 conv to out_channels.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Separate resamplers for the residual path (h) and the skip path (x),
        # so both branches end up at the same spatial size.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            # ... listing truncated here; the class continues through line 256.
```
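The listing stops before the forward pass, so the following NumPy sketch only illustrates how the modules built above could plausibly be wired together; it is not the upstream `forward`. Simplifications and assumptions: 1x1 matrix products stand in for `conv_nd(..., 3, padding=1)`, group norm has no learned affine parameters, dropout is omitted, 2x2 average pooling stands in for `Downsample`, and the scale-shift form `norm(h) * (1 + scale) + shift` for `use_scale_shift_norm` is assumed, since the code that uses that flag is not shown. `resblock_forward`, `params`, and the helper functions are hypothetical names.

```python
import numpy as np


def group_norm(h, groups, eps=1e-5):
    """Group normalization without learned scale/bias; h has shape (N, C, H, W)."""
    n, c, hgt, wid = h.shape
    g = h.reshape(n, groups, c // groups, hgt, wid)
    mean = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(n, c, hgt, wid)


def silu(z):
    """SiLU / swish activation, z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))


def conv1x1(h, weight):
    """Pointwise convolution; weight has shape (C_out, C_in)."""
    return np.einsum("oc,nchw->nohw", weight, h)


def avg_pool2(h):
    """2x2 average pooling (stand-in for Downsample); H and W must be even."""
    n, c, hgt, wid = h.shape
    return h.reshape(n, c, hgt // 2, 2, wid // 2, 2).mean(axis=(3, 5))


def resblock_forward(x, emb, params, groups=2,
                     use_scale_shift_norm=False, down=False):
    # in_layers without its conv: norm -> SiLU
    h = silu(group_norm(x, groups))
    if down:
        # h_upd and x_upd: resample both paths so the final sum lines up
        h, x = avg_pool2(h), avg_pool2(x)
    # in_layers conv (1x1 here; the real block uses a 3x3 conv_nd)
    h = conv1x1(h, params["w_in"])
    # emb_layers: SiLU -> linear, then broadcast over the spatial axes
    emb_out = (silu(emb) @ params["w_emb"].T)[:, :, None, None]
    if use_scale_shift_norm:
        # emb_out carries 2 * C_out values: a per-channel scale and shift
        scale, shift = np.split(emb_out, 2, axis=1)
        h = group_norm(h, groups) * (1 + scale) + shift
    else:
        # emb_out carries C_out values, added before normalization
        h = group_norm(h + emb_out, groups)
    # rest of the output path: SiLU -> (dropout omitted) -> conv
    h = conv1x1(silu(h), params["w_out"])
    # skip connection: identity if no projection weight is given
    w_skip = params.get("w_skip")
    skip = x if w_skip is None else conv1x1(x, w_skip)
    return skip + h


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8, 8))   # N=2, C_in=4, 8x8 spatial
emb = rng.standard_normal((2, 16))      # emb_channels=16
params = {
    "w_in": rng.standard_normal((6, 4)),     # C_in=4 -> C_out=6
    "w_emb": rng.standard_normal((12, 16)),  # 2 * C_out for scale-shift
    "w_out": rng.standard_normal((6, 6)),
    "w_skip": rng.standard_normal((6, 4)),   # channels change, so project
}
y = resblock_forward(x, emb, params, use_scale_shift_norm=True, down=True)
print(y.shape)  # (2, 6, 4, 4)

# With a zero final projection, the block reduces to its (pooled) skip path.
zero = dict(params, w_out=np.zeros((6, 6)))
y0 = resblock_forward(x, emb, zero, use_scale_shift_norm=True, down=True)
print(np.allclose(y0, conv1x1(avg_pool2(x), params["w_skip"])))  # True
```

The zero-projection check illustrates a common design choice for residual blocks in diffusion U-Nets: if the last convolution starts at zero, each block initially passes its (resampled, projected) input straight through. Whether the upstream code does this is not visible in this excerpt.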

Callers (2): __init__ (method) · 0.85; __init__ (method) · 0.85
Calls: no outgoing calls detected
Tested by: no test coverage detected