MCPcopy
hub / github.com/kohya-ss/sd-scripts / SlicingAutoencoderKL

Class SlicingAutoencoderKL

library/slicing_vae.py:541–682  ·  view source on GitHub ↗

r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library implements for all models (such as downloading or saving).

Source from the content-addressed store, hash-verified

539
540
541class SlicingAutoencoderKL(ModelMixin, ConfigMixin):
542 r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
543 and Max Welling.
544
545 This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
546 implements for all the model (such as downloading or saving, etc.)
547
548 Parameters:
549 in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
550 out_channels (int, *optional*, defaults to 3): Number of channels in the output.
551 down_block_types (`Tuple[str]`, *optional*, defaults to :
552 obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
553 up_block_types (`Tuple[str]`, *optional*, defaults to :
554 obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
555 block_out_channels (`Tuple[int]`, *optional*, defaults to :
556 obj:`(64,)`): Tuple of block output channels.
557 act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
558 latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
559 sample_size (`int`, *optional*, defaults to `32`): TODO
560 """
561
562 @register_to_config
563 def __init__(
564 self,
565 in_channels: int = 3,
566 out_channels: int = 3,
567 down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
568 up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
569 block_out_channels: Tuple[int] = (64,),
570 layers_per_block: int = 1,
571 act_fn: str = "silu",
572 latent_channels: int = 4,
573 norm_num_groups: int = 32,
574 sample_size: int = 32,
575 num_slices: int = 16,
576 ):
577 super().__init__()
578
579 # pass init params to Encoder
580 self.encoder = SlicingEncoder(
581 in_channels=in_channels,
582 out_channels=latent_channels,
583 down_block_types=down_block_types,
584 block_out_channels=block_out_channels,
585 layers_per_block=layers_per_block,
586 act_fn=act_fn,
587 norm_num_groups=norm_num_groups,
588 double_z=True,
589 num_slices=num_slices,
590 )
591
592 # pass init params to Decoder
593 self.decoder = SlicingDecoder(
594 in_channels=latent_channels,
595 out_channels=out_channels,
596 up_block_types=up_block_types,
597 block_out_channels=block_out_channels,
598 layers_per_block=layers_per_block,

Callers 3

mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected