MCPcopy
hub / github.com/openai/improved-diffusion / __init__

Method __init__

improved_diffusion/unet.py:122–170  ·  view source on GitHub ↗
(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    )

Source from the content-addressed store, hash-verified

120 """
121
def __init__(
    self,
    channels,
    emb_channels,
    dropout,
    out_channels=None,
    use_conv=False,
    use_scale_shift_norm=False,
    dims=2,
    use_checkpoint=False,
):
    """
    Construct the input, embedding, output and skip-path layers of the block.

    :param channels: number of channels in the block's input.
    :param emb_channels: input width of the linear layer in ``emb_layers``.
    :param dropout: dropout probability used by the ``nn.Dropout`` in
        ``out_layers``.
    :param out_channels: channels produced by the block; any falsy value
        (``None`` by default) falls back to ``channels``.
    :param use_conv: only consulted when the output channel count differs
        from ``channels``: if true, the skip path is
        ``conv_nd(..., 3, padding=1)``, otherwise ``conv_nd(..., 1)``.
    :param use_scale_shift_norm: if true, the embedding projection emits
        ``2 * out_channels`` features instead of ``out_channels``
        (presumably split into scale/shift by ``forward`` -- confirm there).
    :param dims: dimensionality argument forwarded to every ``conv_nd`` call.
    :param use_checkpoint: stored on the instance; not used in this method.
    """
    super().__init__()

    # Configuration kept on the instance for use by other methods.
    self.channels = channels
    self.emb_channels = emb_channels
    self.dropout = dropout
    self.out_channels = out_channels or channels
    self.use_conv = use_conv
    self.use_checkpoint = use_checkpoint
    self.use_scale_shift_norm = use_scale_shift_norm

    out_ch = self.out_channels
    # Scale-shift norm needs two out_ch-sized chunks from the embedding.
    emb_width = out_ch * 2 if use_scale_shift_norm else out_ch

    # NOTE: submodules are built and assigned in a fixed order
    # (in -> emb -> out -> skip). Assignment order sets the submodule
    # registration order, and if the layers draw random initial weights,
    # creation order also decides which RNG draws each layer consumes.
    self.in_layers = nn.Sequential(
        normalization(channels),
        SiLU(),
        conv_nd(dims, channels, out_ch, 3, padding=1),
    )
    self.emb_layers = nn.Sequential(SiLU(), linear(emb_channels, emb_width))
    # The last conv is wrapped in zero_module (presumably zero-initialising
    # it so the residual branch starts near a no-op -- confirm in helper).
    self.out_layers = nn.Sequential(
        normalization(out_ch),
        SiLU(),
        nn.Dropout(p=dropout),
        zero_module(conv_nd(dims, out_ch, out_ch, 3, padding=1)),
    )

    # Skip path: identity when channel counts match, otherwise a projection
    # conv whose size argument is 3 (padded) or 1 depending on use_conv.
    if out_ch != channels:
        if use_conv:
            skip = conv_nd(dims, channels, out_ch, 3, padding=1)
        else:
            skip = conv_nd(dims, channels, out_ch, 1)
    else:
        skip = nn.Identity()
    self.skip_connection = skip
171
172 def forward(self, x, emb):
173 """

Callers

nothing calls this directly

Calls 6

normalization — function · 0.85
SiLU — class · 0.85
conv_nd — function · 0.85
linear — function · 0.85
zero_module — function · 0.85
__init__ — method · 0.45

Tested by

no test coverage detected