387 symbols · 959 edges · 49 files · 129 documented (33%)
README

PyTorch VAE

Update 22/12/2021: Added support for PyTorch Lightning version 1.5.6 and cleaned up the code.

A collection of Variational AutoEncoders (VAEs) implemented in PyTorch with a focus on reproducibility. The aim of this project is to provide a quick and simple working example for many of the cool VAE models out there. All the models are trained on the CelebA dataset for consistency and comparison. The architectures of all the models are kept as similar as possible, with the same layers, except where the original paper necessitates a radically different architecture (e.g., VQ-VAE uses residual layers and no batch normalization, unlike the other models). The results for each model are listed under Results.
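For readers new to VAEs, two ingredients are shared by most models in this collection: the reparameterization trick and the KL term of the loss. The following is a minimal, framework-free sketch of the standard formulation, assuming a diagonal Gaussian posterior and a standard normal prior. It is not the repository's code (which works on PyTorch tensors), and `kl_to_standard_normal`, `vae_loss` and `kld_weight` are hypothetical names used only here.

```python
import math
import random

# Sketch only: assumes a diagonal Gaussian posterior q(z|x) = N(mu, exp(logvar))
# and a standard normal prior p(z) = N(0, I). Lists stand in for tensors.


def reparameterize(mu, logvar, rng=random):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), one value per latent dimension.

    Writing the sample as a deterministic function of (mu, logvar) plus independent
    noise is what lets an autodiff framework backpropagate through mu and logvar.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def vae_loss(recon, target, mu, logvar, kld_weight=1.0):
    """Hypothetical helper: mean squared reconstruction error plus weighted KL."""
    mse = sum((r - t) ** 2 for r, t in zip(recon, target)) / len(target)
    return mse + kld_weight * kl_to_standard_normal(mu, logvar)


z = reparameterize([0.0, 1.0], [0.0, -2.0], random.Random(0))
print(z, kl_to_standard_normal([1.0], [0.0]))
```

The KL term vanishes exactly when the posterior equals the prior (`mu = 0`, `logvar = 0`); variants such as Beta-VAE mainly change how this term is weighted.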

Requirements

  • Python >= 3.5
  • PyTorch >= 1.3
  • PyTorch Lightning >= 0.6.0 (GitHub Repo)
  • CUDA-enabled computing device

Installation

$ git clone https://github.com/AntixK/PyTorch-VAE
$ cd PyTorch-VAE
$ pip install -r requirements.txt

Usage

$ cd PyTorch-VAE
$ python run.py -c configs/<config-file-name.yaml>

Config file template

model_params:
  name: "<name of VAE model>"
  in_channels: 3
  latent_dim: 
    .         # Other parameters required by the model
    .
    .

data_params:
  data_path: "<path to the celebA dataset>"
  train_batch_size: 64 # Better to have a square number
  val_batch_size:  64
  patch_size: 64  # Models are designed to work for this size
  num_workers: 4

exp_params:
  manual_seed: 1265
  LR: 0.005
  weight_decay:
    .         # Other arguments required for training, like scheduler etc.
    .
    .

trainer_params:
  gpus: 1         
  max_epochs: 100
  gradient_clip_val: 1.5
    .
    .
    .

logging_params:
  save_dir: "logs/"
  name: "<experiment name>"
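To make the template concrete, here is a small sketch that sanity-checks a parsed config dict against the sections above and the two comments in `data_params`. It assumes the YAML has already been parsed (for example with PyYAML's `yaml.safe_load`). `check_config`, `is_perfect_square` and the example values (`VanillaVAE`, `latent_dim: 128`, a batch size of 60) are hypothetical and chosen only for illustration. The README does not say why a square batch size is preferred; a plausible reason is that image batches tile neatly into a square grid, but treat that as an assumption.

```python
import math

# Top-level sections used in the config template.
REQUIRED_SECTIONS = ("model_params", "data_params", "exp_params",
                     "trainer_params", "logging_params")


def is_perfect_square(n):
    """True for non-negative integer squares such as 64 = 8 * 8 (math.isqrt needs Python 3.8+)."""
    return n >= 0 and math.isqrt(n) ** 2 == n


def check_config(cfg):
    """Hypothetical checker: return human-readable problems found in a parsed config dict."""
    problems = [f"missing section: {s}" for s in REQUIRED_SECTIONS if s not in cfg]
    data = cfg.get("data_params", {})
    for key in ("train_batch_size", "val_batch_size"):
        n = data.get(key)
        if isinstance(n, int) and not is_perfect_square(n):
            problems.append(f"{key}={n} is not a perfect square")
    # The template notes that the models are designed for 64x64 patches.
    if data.get("patch_size", 64) != 64:
        problems.append(f"patch_size={data['patch_size']} differs from 64")
    return problems


# Hypothetical parsed config; in practice this dict would come from yaml.safe_load().
example = {
    "model_params": {"name": "VanillaVAE", "in_channels": 3, "latent_dim": 128},
    "data_params": {"data_path": "Data/celeba/img_align_celeba", "train_batch_size": 60,
                    "val_batch_size": 64, "patch_size": 64, "num_workers": 4},
    "exp_params": {"manual_seed": 1265, "LR": 0.005},
    "trainer_params": {"gpus": 1, "max_epochs": 100, "gradient_clip_val": 1.5},
    "logging_params": {"save_dir": "logs/", "name": "VanillaVAE"},
}
print(check_config(example))  # ['train_batch_size=60 is not a perfect square']
```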

View TensorBoard Logs

$ cd logs/<experiment name>/version_<the version you want>
$ tensorboard --logdir .

Note: The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). The recommendation is therefore to download the file from Google Drive directly and extract it to a path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`, but you can change it according to your preference.
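Because the dataset has to be fetched manually, a quick check that the images landed where the config expects them can save a failed run. This is a hedged sketch using only the standard library; it assumes the extracted images are `.jpg` files sitting directly in the folder, and `celeba_folder_status` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Default location quoted in the README; change it if you extracted elsewhere.
DEFAULT_CELEBA_DIR = "Data/celeba/img_align_celeba"


def celeba_folder_status(root=DEFAULT_CELEBA_DIR):
    """Hypothetical helper: return (folder_exists, number_of_jpg_files) for root."""
    path = Path(root)
    if not path.is_dir():
        return False, 0
    # Assumes the images are stored directly in the folder as .jpg files.
    n_images = sum(1 for p in path.iterdir() if p.is_file() and p.suffix.lower() == ".jpg")
    return True, n_images


if __name__ == "__main__":
    exists, n = celeba_folder_status()
    print(f"found={exists} images={n}")
```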


Results

Model | Code | Paper
VAE | Code, Config | Link
Conditional VAE | Code, Config | Link
WAE - MMD (RBF Kernel) | Code, Config | Link
WAE - MMD (IMQ Kernel) | Code, Config | Link
Beta-VAE | Code, Config | Link
Disentangled Beta-VAE | Code, Config | Link
Beta-TC-VAE | Code, Config | Link
IWAE (K = 5) | Code, Config | Link
MIWAE (K = 5, M = 3) | Code, Config | Link
DFCVAE | Code, Config | Link
MSSIM VAE | Code, Config | Link
Categorical VAE | Code, Config | Link
Joint VAE | Code, Config | Link
Info VAE | Code, Config | Link
LogCosh VAE | Code, Config | Link
SWAE (200 Projections) | Code, Config | Link
VQ-VAE (K = 512, D = 64) | Code, Config | Link (samples: N/A)
DIP VAE | Code, Config | Link
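Two of the rows above differ only in the kernel used for the MMD penalty of WAE-MMD (RBF versus IMQ), and Info VAE also relies on a kernel-based MMD term. The sketch below shows what such a penalty computes, using plain Python lists instead of tensors. It is an illustration under stated assumptions, not the repository's `compute_kernel`: the biased estimator and the default bandwidths `sigma2=1.0` and `c=1.0` are placeholders, and in practice the bandwidth is usually tied to the latent dimension.

```python
import math


def _sq_dist(x, y):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def rbf_kernel(x, y, sigma2=1.0):
    """Gaussian (RBF) kernel exp(-||x - y||^2 / (2 * sigma2)); sigma2 is a placeholder."""
    return math.exp(-_sq_dist(x, y) / (2.0 * sigma2))


def imq_kernel(x, y, c=1.0):
    """Inverse multiquadratic kernel c / (c + ||x - y||^2); c is a placeholder."""
    return c / (c + _sq_dist(x, y))


def mmd2(xs, ys, kernel):
    """Biased (V-statistic) estimate of the squared MMD between samples xs and ys."""
    def mean_k(a, b):
        return sum(kernel(p, q) for p in a for q in b) / (len(a) * len(b))
    return mean_k(xs, xs) + mean_k(ys, ys) - 2.0 * mean_k(xs, ys)


# In WAE-MMD / Info VAE, xs would be encoded latents and ys samples from the N(0, I) prior.
encoded = [[0.1, -0.2], [0.0, 0.3], [-0.1, 0.1]]
prior = [[0.05, -0.1], [0.2, 0.2], [-0.3, 0.0]]
print(mmd2(encoded, prior, rbf_kernel), mmd2(encoded, prior, imq_kernel))
```

The IMQ kernel has heavier tails than the RBF kernel, so it keeps penalizing latents that sit far from the prior samples instead of letting the kernel value decay to almost zero.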

Contributing

If you have trained a better model using these implementations by fine-tuning the hyperparameters in the config file, I would be happy to include your result (along with your config file) in this repo, citing your name 😊.

Additionally, if you would like to contribute some models, please submit a PR.

License

Apache License 2.0

Permissions: ✔️ Commercial use, ✔️ Modification, ✔️ Distribution, ✔️ Patent use, ✔️ Private use
Limitations: ❌ Trademark use, ❌ Liability, ❌ Warranty
Conditions: ⓘ License and copyright notice, ⓘ State changes

Citation

@misc{Subramanian2020,
  author = {Subramanian, A.K},
  title = {PyTorch-VAE},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/AntixK/PyTorch-VAE}}
}

Core symbols most depended-on inside this repo

loss_function (models/hvae.py), called by 23
sample (models/hvae.py), called by 13
generate (models/hvae.py), called by 5
reparameterize (models/hvae.py), called by 3
compute_kernel (models/info_vae.py), called by 3
reparameterize (models/lvae.py), called by 3
compute_kernel (models/wae_mmd.py), called by 3
log_density_gaussian (models/betatc_vae.py), called by 3
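`log_density_gaussian` in `models/betatc_vae.py` is one of the core symbols above. Beta-TC-VAE needs per-dimension Gaussian log densities to estimate the aggregate-posterior terms in its total-correlation decomposition. Below is a scalar sketch of the standard formula, assuming the variance is parameterized as a log-variance; the repository's version operates on tensors and its exact signature may differ, and `log_density_diag_gaussian` is a hypothetical helper added here.

```python
import math

LOG_2PI = math.log(2.0 * math.pi)


def log_density_gaussian(x, mu, logvar):
    """log N(x; mu, sigma^2) with sigma^2 = exp(logvar), for scalar inputs (sketch)."""
    return -0.5 * (LOG_2PI + logvar + (x - mu) ** 2 * math.exp(-logvar))


def log_density_diag_gaussian(xs, mus, logvars):
    """Hypothetical helper: log density of a diagonal Gaussian, summed over dimensions."""
    return sum(log_density_gaussian(x, m, lv) for x, m, lv in zip(xs, mus, logvars))


# x = 1 under N(0, 4): equals -log(2) - 0.5 * log(2 * pi) - 1/8.
print(log_density_gaussian(1.0, 0.0, math.log(4.0)))
```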

Shape

Method 330
Class 53
Function 4

Languages

Python 100%

Modules by API surface

models/lvae.py: 18 symbols
dataset.py: 16 symbols
models/mssim_vae.py: 15 symbols
models/gamma_vae.py: 15 symbols
models/vq_vae.py: 14 symbols
models/info_vae.py: 13 symbols
models/wae_mmd.py: 12 symbols
models/hvae.py: 11 symbols
models/swae.py: 10 symbols
models/fvae.py: 10 symbols
models/dfcvae.py: 10 symbols
models/betatc_vae.py: 10 symbols

Dependencies from manifests, versioned

PyYAML 6.0 · 1×
pytorch-lightning 1.5.6 · 1×
tensorboard 2.2.0 · 1×
torch 1.6.1 · 1×
torchsummary 1.5.1 · 1×
torchvision 0.10.1 · 1×

For agents

$ claude mcp add PyTorch-VAE \
  -- python -m otcore.mcp_server <graph>
