(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
)
| 425 | """ |
| 426 | |
| 427 | def __init__( |
| 428 | self, |
| 429 | image_size, |
| 430 | in_channels, |
| 431 | model_channels, |
| 432 | out_channels, |
| 433 | num_res_blocks, |
| 434 | attention_resolutions, |
| 435 | dropout=0, |
| 436 | channel_mult=(1, 2, 4, 8), |
| 437 | conv_resample=True, |
| 438 | dims=2, |
| 439 | num_classes=None, |
| 440 | use_checkpoint=False, |
| 441 | use_fp16=False, |
| 442 | num_heads=1, |
| 443 | num_head_channels=-1, |
| 444 | num_heads_upsample=-1, |
| 445 | use_scale_shift_norm=False, |
| 446 | resblock_updown=False, |
| 447 | use_new_attention_order=False, |
| 448 | ): |
| 449 | super().__init__() |
| 450 | |
| 451 | if num_heads_upsample == -1: |
| 452 | num_heads_upsample = num_heads |
| 453 | |
| 454 | self.image_size = image_size |
| 455 | self.in_channels = in_channels |
| 456 | self.model_channels = model_channels |
| 457 | self.out_channels = out_channels |
| 458 | self.num_res_blocks = num_res_blocks |
| 459 | self.attention_resolutions = attention_resolutions |
| 460 | self.dropout = dropout |
| 461 | self.channel_mult = channel_mult |
| 462 | self.conv_resample = conv_resample |
| 463 | self.num_classes = num_classes |
| 464 | self.use_checkpoint = use_checkpoint |
| 465 | self.dtype = th.float16 if use_fp16 else th.float32 |
| 466 | self.num_heads = num_heads |
| 467 | self.num_head_channels = num_head_channels |
| 468 | self.num_heads_upsample = num_heads_upsample |
| 469 | |
| 470 | time_embed_dim = model_channels * 4 |
| 471 | self.time_embed = nn.Sequential( |
| 472 | linear(model_channels, time_embed_dim), |
| 473 | nn.SiLU(), |
| 474 | linear(time_embed_dim, time_embed_dim), |
| 475 | ) |
| 476 | |
| 477 | if self.num_classes is not None: |
| 478 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) |
| 479 | |
| 480 | ch = input_ch = int(channel_mult[0] * model_channels) |
| 481 | self.input_blocks = nn.ModuleList( |
| 482 | [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] |
| 483 | ) |
| 484 | self._feature_size = ch |
Nothing calls this `__init__` method directly.
No test coverage was detected for it.