(
self,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
spatial_transformer_attn_type="softmax",
adm_in_channels=None,
use_fairscale_checkpoint=False,
offload_to_cpu=False,
transformer_depth_middle=None,
dtype="fp32",
lora_init=False,
lora_rank=4,
lora_scale=1.0,
lora_weight_path=None,
)
| 524 | """ |
| 525 | |
| 526 | def __init__( |
| 527 | self, |
| 528 | in_channels, |
| 529 | model_channels, |
| 530 | out_channels, |
| 531 | num_res_blocks, |
| 532 | attention_resolutions, |
| 533 | dropout=0, |
| 534 | channel_mult=(1, 2, 4, 8), |
| 535 | conv_resample=True, |
| 536 | dims=2, |
| 537 | num_classes=None, |
| 538 | use_checkpoint=False, |
| 539 | use_fp16=False, |
| 540 | num_heads=-1, |
| 541 | num_head_channels=-1, |
| 542 | num_heads_upsample=-1, |
| 543 | use_scale_shift_norm=False, |
| 544 | resblock_updown=False, |
| 545 | use_new_attention_order=False, |
| 546 | use_spatial_transformer=False, # custom transformer support |
| 547 | transformer_depth=1, # custom transformer support |
| 548 | context_dim=None, # custom transformer support |
| 549 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model |
| 550 | legacy=True, |
| 551 | disable_self_attentions=None, |
| 552 | num_attention_blocks=None, |
| 553 | disable_middle_self_attn=False, |
| 554 | use_linear_in_transformer=False, |
| 555 | spatial_transformer_attn_type="softmax", |
| 556 | adm_in_channels=None, |
| 557 | use_fairscale_checkpoint=False, |
| 558 | offload_to_cpu=False, |
| 559 | transformer_depth_middle=None, |
| 560 | dtype="fp32", |
| 561 | lora_init=False, |
| 562 | lora_rank=4, |
| 563 | lora_scale=1.0, |
| 564 | lora_weight_path=None, |
| 565 | ): |
| 566 | super().__init__() |
| 567 | from omegaconf.listconfig import ListConfig |
| 568 | |
| 569 | self.dtype = str_to_dtype[dtype] |
| 570 | |
| 571 | if use_spatial_transformer: |
| 572 | assert ( |
| 573 | context_dim is not None |
| 574 | ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." |
| 575 | |
| 576 | if context_dim is not None: |
| 577 | assert ( |
| 578 | use_spatial_transformer |
| 579 | ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." |
| 580 | if type(context_dim) == ListConfig: |
| 581 | context_dim = list(context_dim) |
| 582 | |
| 583 | if num_heads_upsample == -1: |
nothing calls this directly
no test coverage detected