(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
num_heads=1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
)
| 299 | """ |
| 300 | |
| 301 | def __init__( |
| 302 | self, |
| 303 | in_channels, |
| 304 | model_channels, |
| 305 | out_channels, |
| 306 | num_res_blocks, |
| 307 | attention_resolutions, |
| 308 | dropout=0, |
| 309 | channel_mult=(1, 2, 4, 8), |
| 310 | conv_resample=True, |
| 311 | dims=2, |
| 312 | num_classes=None, |
| 313 | use_checkpoint=False, |
| 314 | num_heads=1, |
| 315 | num_heads_upsample=-1, |
| 316 | use_scale_shift_norm=False, |
| 317 | ): |
| 318 | super().__init__() |
| 319 | |
| 320 | if num_heads_upsample == -1: |
| 321 | num_heads_upsample = num_heads |
| 322 | |
| 323 | self.in_channels = in_channels |
| 324 | self.model_channels = model_channels |
| 325 | self.out_channels = out_channels |
| 326 | self.num_res_blocks = num_res_blocks |
| 327 | self.attention_resolutions = attention_resolutions |
| 328 | self.dropout = dropout |
| 329 | self.channel_mult = channel_mult |
| 330 | self.conv_resample = conv_resample |
| 331 | self.num_classes = num_classes |
| 332 | self.use_checkpoint = use_checkpoint |
| 333 | self.num_heads = num_heads |
| 334 | self.num_heads_upsample = num_heads_upsample |
| 335 | |
| 336 | time_embed_dim = model_channels * 4 |
| 337 | self.time_embed = nn.Sequential( |
| 338 | linear(model_channels, time_embed_dim), |
| 339 | SiLU(), |
| 340 | linear(time_embed_dim, time_embed_dim), |
| 341 | ) |
| 342 | |
| 343 | if self.num_classes is not None: |
| 344 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) |
| 345 | |
| 346 | self.input_blocks = nn.ModuleList( |
| 347 | [ |
| 348 | TimestepEmbedSequential( |
| 349 | conv_nd(dims, in_channels, model_channels, 3, padding=1) |
| 350 | ) |
| 351 | ] |
| 352 | ) |
| 353 | input_block_chans = [model_channels] |
| 354 | ch = model_channels |
| 355 | ds = 1 |
| 356 | for level, mult in enumerate(channel_mult): |
| 357 | for _ in range(num_res_blocks): |
| 358 | layers = [ |
nothing calls this directly
no test coverage detected