MCPcopy Index your code
hub / github.com/openai/guided-diffusion / __init__

Method __init__

guided_diffusion/unet.py:160–222  ·  view source on GitHub ↗
(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    )

Source from the content-addressed store, hash-verified

158 """
159
160 def __init__(
161 self,
162 channels,
163 emb_channels,
164 dropout,
165 out_channels=None,
166 use_conv=False,
167 use_scale_shift_norm=False,
168 dims=2,
169 use_checkpoint=False,
170 up=False,
171 down=False,
172 ):
173 super().__init__()
174 self.channels = channels
175 self.emb_channels = emb_channels
176 self.dropout = dropout
177 self.out_channels = out_channels or channels
178 self.use_conv = use_conv
179 self.use_checkpoint = use_checkpoint
180 self.use_scale_shift_norm = use_scale_shift_norm
181
182 self.in_layers = nn.Sequential(
183 normalization(channels),
184 nn.SiLU(),
185 conv_nd(dims, channels, self.out_channels, 3, padding=1),
186 )
187
188 self.updown = up or down
189
190 if up:
191 self.h_upd = Upsample(channels, False, dims)
192 self.x_upd = Upsample(channels, False, dims)
193 elif down:
194 self.h_upd = Downsample(channels, False, dims)
195 self.x_upd = Downsample(channels, False, dims)
196 else:
197 self.h_upd = self.x_upd = nn.Identity()
198
199 self.emb_layers = nn.Sequential(
200 nn.SiLU(),
201 linear(
202 emb_channels,
203 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
204 ),
205 )
206 self.out_layers = nn.Sequential(
207 normalization(self.out_channels),
208 nn.SiLU(),
209 nn.Dropout(p=dropout),
210 zero_module(
211 conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
212 ),
213 )
214
215 if self.out_channels == channels:
216 self.skip_connection = nn.Identity()
217 elif use_conv:

Callers

nothing calls this directly

Calls 7

normalization — Function · 0.85
conv_nd — Function · 0.85
Upsample — Class · 0.85
Downsample — Class · 0.85
linear — Function · 0.85
zero_module — Function · 0.85
__init__ — Method · 0.45

Tested by

no test coverage detected