MCPcopy
hub / github.com/ermongroup/ddim / Model

Class Model

models/diffusion.py:192–341  ·  view source on GitHub ↗

Source retrieved from the content-addressed store (hash-verified).

190
191
192class Model(nn.Module):
193 def __init__(self, config):
194 super().__init__()
195 self.config = config
196 ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult)
197 num_res_blocks = config.model.num_res_blocks
198 attn_resolutions = config.model.attn_resolutions
199 dropout = config.model.dropout
200 in_channels = config.model.in_channels
201 resolution = config.data.image_size
202 resamp_with_conv = config.model.resamp_with_conv
203 num_timesteps = config.diffusion.num_diffusion_timesteps
204
205 if config.model.type == 'bayesian':
206 self.logvar = nn.Parameter(torch.zeros(num_timesteps))
207
208 self.ch = ch
209 self.temb_ch = self.ch*4
210 self.num_resolutions = len(ch_mult)
211 self.num_res_blocks = num_res_blocks
212 self.resolution = resolution
213 self.in_channels = in_channels
214
215 # timestep embedding
216 self.temb = nn.Module()
217 self.temb.dense = nn.ModuleList([
218 torch.nn.Linear(self.ch,
219 self.temb_ch),
220 torch.nn.Linear(self.temb_ch,
221 self.temb_ch),
222 ])
223
224 # downsampling
225 self.conv_in = torch.nn.Conv2d(in_channels,
226 self.ch,
227 kernel_size=3,
228 stride=1,
229 padding=1)
230
231 curr_res = resolution
232 in_ch_mult = (1,)+ch_mult
233 self.down = nn.ModuleList()
234 block_in = None
235 for i_level in range(self.num_resolutions):
236 block = nn.ModuleList()
237 attn = nn.ModuleList()
238 block_in = ch*in_ch_mult[i_level]
239 block_out = ch*ch_mult[i_level]
240 for i_block in range(self.num_res_blocks):
241 block.append(ResnetBlock(in_channels=block_in,
242 out_channels=block_out,
243 temb_channels=self.temb_ch,
244 dropout=dropout))
245 block_in = block_out
246 if curr_res in attn_resolutions:
247 attn.append(AttnBlock(block_in))
248 down = nn.Module()
249 down.block = block

Callers 2

train (method) · 0.90
sample (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected