r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library implements for all the model (such a
| 539 | |
| 540 | |
| 541 | class SlicingAutoencoderKL(ModelMixin, ConfigMixin): |
| 542 | r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma |
| 543 | and Max Welling. |
| 544 | |
| 545 | This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library |
| 546 | implements for all models (such as downloading or saving). |
| 547 | |
| 548 | Parameters: |
| 549 | in_channels (`int`, *optional*, defaults to `3`): Number of channels in the input image. |
| 550 | out_channels (`int`, *optional*, defaults to `3`): Number of channels in the output. |
| 551 | down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): |
| 552 | Tuple of downsample block types. |
| 553 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): |
| 554 | Tuple of upsample block types. |
| 555 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): |
| 556 | Tuple of block output channels. |
| 557 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. |
| 558 | latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space. |
| 559 | sample_size (`int`, *optional*, defaults to `32`): TODO |
| 560 | """ |
| 561 | |
| 562 | @register_to_config |
| 563 | def __init__( |
| 564 | self, |
| 565 | in_channels: int = 3, |
| 566 | out_channels: int = 3, |
| 567 | down_block_types: Tuple[str] = ("DownEncoderBlock2D",), |
| 568 | up_block_types: Tuple[str] = ("UpDecoderBlock2D",), |
| 569 | block_out_channels: Tuple[int] = (64,), |
| 570 | layers_per_block: int = 1, |
| 571 | act_fn: str = "silu", |
| 572 | latent_channels: int = 4, |
| 573 | norm_num_groups: int = 32, |
| 574 | sample_size: int = 32, |
| 575 | num_slices: int = 16, |
| 576 | ): |
| 577 | super().__init__() |
| 578 | |
| 579 | # pass init params to Encoder |
| 580 | self.encoder = SlicingEncoder( |
| 581 | in_channels=in_channels, |
| 582 | out_channels=latent_channels, |
| 583 | down_block_types=down_block_types, |
| 584 | block_out_channels=block_out_channels, |
| 585 | layers_per_block=layers_per_block, |
| 586 | act_fn=act_fn, |
| 587 | norm_num_groups=norm_num_groups, |
| 588 | double_z=True, |
| 589 | num_slices=num_slices, |
| 590 | ) |
| 591 | |
| 592 | # pass init params to Decoder |
| 593 | self.decoder = SlicingDecoder( |
| 594 | in_channels=latent_channels, |
| 595 | out_channels=out_channels, |
| 596 | up_block_types=up_block_types, |
| 597 | block_out_channels=block_out_channels, |
| 598 | layers_per_block=layers_per_block, |