(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
activation,
encoder_dim,
att_pool_heads,
encoder_channels,
image_size,
disable_self_attentions=None,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
precision='32',
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
efficient_activation=False,
scale_skip_connection=False,
)
| 351 | """ |
| 352 | |
| 353 | def __init__( |
| 354 | self, |
| 355 | in_channels, |
| 356 | model_channels, |
| 357 | out_channels, |
| 358 | num_res_blocks, |
| 359 | attention_resolutions, |
| 360 | activation, |
| 361 | encoder_dim, |
| 362 | att_pool_heads, |
| 363 | encoder_channels, |
| 364 | image_size, |
| 365 | disable_self_attentions=None, |
| 366 | dropout=0, |
| 367 | channel_mult=(1, 2, 4, 8), |
| 368 | conv_resample=True, |
| 369 | dims=2, |
| 370 | num_classes=None, |
| 371 | precision='32', |
| 372 | num_heads=1, |
| 373 | num_head_channels=-1, |
| 374 | num_heads_upsample=-1, |
| 375 | use_scale_shift_norm=False, |
| 376 | resblock_updown=False, |
| 377 | efficient_activation=False, |
| 378 | scale_skip_connection=False, |
| 379 | ): |
| 380 | super().__init__() |
| 381 | |
| 382 | if num_heads_upsample == -1: |
| 383 | num_heads_upsample = num_heads |
| 384 | |
| 385 | self.encoder_channels = encoder_channels |
| 386 | self.encoder_dim = encoder_dim |
| 387 | self.efficient_activation = efficient_activation |
| 388 | self.scale_skip_connection = scale_skip_connection |
| 389 | self.in_channels = in_channels |
| 390 | self.model_channels = model_channels |
| 391 | self.out_channels = out_channels |
| 392 | self.dropout = dropout |
| 393 | |
| 394 | # adapt attention resolutions |
| 395 | if isinstance(attention_resolutions, str): |
| 396 | self.attention_resolutions = [] |
| 397 | for res in attention_resolutions.split(','): |
| 398 | self.attention_resolutions.append(image_size // int(res)) |
| 399 | else: |
| 400 | self.attention_resolutions = attention_resolutions |
| 401 | self.attention_resolutions = tuple(self.attention_resolutions) |
| 402 | # |
| 403 | |
| 404 | # adapt disable self attention resolutions |
| 405 | if not disable_self_attentions: |
| 406 | self.disable_self_attentions = [] |
| 407 | elif disable_self_attentions is True: |
| 408 | self.disable_self_attentions = attention_resolutions |
| 409 | elif isinstance(disable_self_attentions, str): |
| 410 | self.disable_self_attentions = [] |
nothing calls this directly
no test coverage detected