StyleGAN2 with adaptive discriminator augmentation (ADA): Official TensorFlow implementation

Training Generative Adversarial Networks with Limited Data
Tero Karras, Miika Aittala, Janne Hellsten, Samuli Laine, Jaakko Lehtinen, Timo Aila
https://arxiv.org/abs/2006.06676
Abstract: Training generative adversarial networks (GAN) using too little data typically leads to discriminator overfitting, causing training to diverge. We propose an adaptive discriminator augmentation mechanism that significantly stabilizes training in limited data regimes. The approach does not require changes to loss functions or network architectures, and is applicable both when training from scratch and when fine-tuning an existing GAN on another dataset. We demonstrate, on several datasets, that good results are now possible using only a few thousand training images, often matching StyleGAN2 results with an order of magnitude fewer images. We expect this to open up new application domains for GANs. We also find that the widely used CIFAR-10 is, in fact, a limited data benchmark, and improve the record FID from 5.59 to 2.42.
For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing
The Official PyTorch version is now available and supersedes the TensorFlow version. See the full list of versions here.
This repository supersedes the original StyleGAN2 and adds new features. The following additional material is available:
| Path | Description |
|---|---|
| stylegan2-ada | Main directory hosted on Amazon S3 |
| ├ ada-paper.pdf | Paper PDF |
| ├ images | Curated example images produced using the pre-trained models |
| ├ videos | Curated example interpolation videos |
| └ pretrained | Pre-trained models |
| ├ metfaces.pkl | MetFaces at 1024x1024, transfer learning from FFHQ using ADA |
| ├ brecahad.pkl | BreCaHAD at 512x512, trained from scratch using ADA |
| ├ afhqcat.pkl | AFHQ Cat at 512x512, trained from scratch using ADA |
| ├ afhqdog.pkl | AFHQ Dog at 512x512, trained from scratch using ADA |
| ├ afhqwild.pkl | AFHQ Wild at 512x512, trained from scratch using ADA |
| ├ cifar10.pkl | Class-conditional CIFAR-10 at 32x32 |
| ├ ffhq.pkl | FFHQ at 1024x1024, trained using original StyleGAN2 |
| ├ paper-fig7c-training-set-sweeps | All models used in Fig.7c (baseline, ADA, bCR) |
| ├ paper-fig8a-comparison-methods | All models used in Fig.8a (comparison methods) |
| ├ paper-fig8b-discriminator-capacity | All models used in Fig.8b (discriminator capacity) |
| ├ paper-fig11a-small-datasets | All models used in Fig.11a (small datasets, transfer learning) |
| ├ paper-fig11b-cifar10 | All models used in Fig.11b (CIFAR-10) |
| ├ transfer-learning-source-nets | Models used as starting point for transfer learning |
| └ metrics | Feature detectors used by the quality metrics |
The generator and discriminator networks rely heavily on custom TensorFlow ops that are compiled on the fly using NVCC. On Windows, the compilation requires Microsoft Visual Studio to be in PATH. We recommend installing Visual Studio Community Edition and adding it into PATH using "C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Auxiliary\Build\vcvars64.bat".
Pre-trained networks are stored as *.pkl files that can be referenced using local filenames or URLs:
# Generate curated MetFaces images without truncation (Fig.10 left)
python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
# Generate uncurated MetFaces images with truncation (Fig.12 upper left)
python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl
# Generate class conditional CIFAR-10 images (Fig.17 left, Car)
python generate.py --outdir=out --trunc=1 --seeds=0-35 --class=1 \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/cifar10.pkl
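The commands above pass --seeds either as a comma-separated list (85,265,297,849) or as a range (600-605). The following is a minimal sketch of how such a value could be expanded into individual seeds. The function name parse_seeds is hypothetical, and treating the range as inclusive of both endpoints is an assumption, not something the text above states.

```python
import re


def parse_seeds(spec):
    """Expand a seed spec such as '85,265,297,849' or '600-605' into a list of ints."""
    match = re.fullmatch(r"(\d+)-(\d+)", spec)
    if match:
        first, last = int(match.group(1)), int(match.group(2))
        # Assumption: a range covers both endpoints, so '600-605' yields six seeds.
        return list(range(first, last + 1))
    # Otherwise treat the value as a comma-separated list of integers.
    return [int(part) for part in spec.split(",")]


if __name__ == "__main__":
    print(parse_seeds("85,265,297,849"))
    print(parse_seeds("600-605"))
```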
Outputs from the above commands are placed under out/*.png. You can change the location with --outdir. Temporary cache files, such as CUDA build results and downloaded network pickles, will be saved under $HOME/.cache/dnnlib. This can be overridden using the DNNLIB_CACHE_DIR environment variable.
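The cache location rule described above (DNNLIB_CACHE_DIR if set, otherwise $HOME/.cache/dnnlib) can be sketched as follows. This is an illustration only: resolve_cache_dir is a hypothetical helper, not a function from this repository, and the actual lookup in dnnlib may differ in detail.

```python
import os


def resolve_cache_dir(environ=None):
    """Return the cache directory implied by the given environment mapping."""
    env = os.environ if environ is None else environ
    override = env.get("DNNLIB_CACHE_DIR")
    if override:
        # An explicit override takes precedence over the default location.
        return override
    # Fall back to $HOME/.cache/dnnlib (HOME taken from the mapping if present).
    home = env.get("HOME", os.path.expanduser("~"))
    return os.path.join(home, ".cache", "dnnlib")


if __name__ == "__main__":
    print(resolve_cache_dir())
```

Passing the environment as a mapping keeps the helper easy to check without modifying the real process environment.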
Docker: You can run the above curated image example using Docker as follows:
docker build --tag stylegan2ada:latest .
docker run --gpus all -it --rm -v `pwd`:/scratch --user $(id -u):$(id -g) stylegan2ada:latest bash -c \
"(cd /scratch && DNNLIB_CACHE_DIR=/scratch/.cache python3 generate.py --trunc=1 --seeds=85,265,297,849 \
--outdir=out --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metfaces.pkl)"
Note: The above defaults to a container base image that requires NVIDIA driver release r455.23 or later. To build an image for older drivers and GPUs, run:
docker build --build-arg BASE_IMAGE=tensorflow/tensorflow:1.14.0-gpu-py3 --tag stylegan2ada:latest .
To find the matching latent vector for a given image file, run:
python projector.py --outdir=out --target=targetimg.png \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
For optimal results, the target image should be cropped and aligned similar to the original FFHQ dataset. The above command saves the projection target out/target.png, result out/proj.png, latent vector out/dlatents.npz, and progression video out/proj.mp4. You can render the resulting latent vector by specifying --dlatents for python generate.py:
python generate.py --outdir=out --dlatents=out/dlatents.npz \
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl
Datasets are stored as multi-resolution TFRecords, i.e., the same format used by StyleGAN and StyleGAN2. Each dataset consists of multiple *.tfrecords files stored under a common directory, e.g., ~/datasets/ffhq/ffhq-r*.tfrecords
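To see which resolution levels a dataset directory contains, the *-r*.tfrecords naming pattern mentioned above can be inspected with the standard library. The sketch below uses a hypothetical helper, list_resolution_levels. It assumes the numeric suffix after "-r" is the base-2 logarithm of the image resolution (so r10 would correspond to 1024x1024), which the text above does not state.

```python
import glob
import os
import re


def list_resolution_levels(dataset_dir):
    """Return sorted (level, path) pairs for the *-r<NN>.tfrecords files in dataset_dir."""
    pattern = os.path.join(os.path.expanduser(dataset_dir), "*-r*.tfrecords")
    levels = []
    for path in glob.glob(pattern):
        match = re.search(r"-r(\d+)\.tfrecords$", os.path.basename(path))
        if match:
            levels.append((int(match.group(1)), path))
    return sorted(levels)


if __name__ == "__main__":
    for level, path in list_resolution_levels("~/datasets/ffhq"):
        # Assumption: level is log2 of the resolution, so 2 ** level is the image size.
        print(f"{path}: {2 ** level}x{2 ** level}")
```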
MetFaces: Download the MetFaces dataset and convert to TFRecords:
python dataset_tool.py create_from_images ~/datasets/metfaces ~/downloads/metfaces/images
python dataset_tool.py display ~/datasets/metfaces
BreCaHAD: Download the BreCaHAD dataset. Generate 512x512 resolution crops and convert to TFRecords:
python dataset_tool.py extract_brecahad_crops --cropsize=512 \
--output_dir=/tmp/brecahad-crops --brecahad_dir=~/downloads/brecahad/images
python dataset_tool.py create_from_images ~/datasets/brecahad /tmp/brecahad-crops
python dataset_tool.py display ~/datasets/brecahad
AFHQ: Download the AFHQ dataset and convert to TFRecords:
python dataset_tool.py create_from_images ~/datasets/afhqcat ~/downloads/afhq/train/cat
python dataset_tool.py create_from_images ~/datasets/afhqdog ~/downloads/afhq/train/dog
python dataset_tool.py create_from_images ~/datasets/afhqwild ~/downloads/afhq/train/wild
python dataset_tool.py display ~/datasets/afhqcat
CIFAR-10: Download the CIFAR-10 python version. Convert to two separate TFRecords for unconditional and class-conditional training:
python dataset_tool.py create_cifar10 --ignore_labels=1 \
~/datasets/cifar10u ~/downloads/cifar-10-batches-py
python dataset_tool.py create_cifar10 --ignore_labels=0 \
~/datasets/cifar10c ~/downloads/cifar-10-batches-py
python dataset_tool.py display ~/datasets/cifar10c
FFHQ: Download the Flickr-Faces-HQ dataset as TFRecords:
pushd ~
git clone https://github.com/NVlabs/ffhq-dataset.git
cd ffhq-dataset
python download_ffhq.py --tfrecords
popd
python dataset_tool.py display ~/ffhq-dataset/tfrecords/ffhq
LSUN: Download the desired LSUN categories in LMDB format from the LSUN project page and convert to TFRecords:
python dataset_tool.py create_lsun --resolution=256 --max_images=200000 \
~/datasets/lsuncat200k ~/downloads/lsun/cat_lmdb
python dataset_tool.py display ~/datasets/lsuncat200k
Custom: Custom datasets can be created by placing all images under a single directory. The images must be square-shaped and they must all have the same power-of-two dimensions. To convert the images to multi-resolution TFRecords, run:
python dataset_tool.py create_from_images ~/datasets/custom ~/custom-images
python dataset_tool.py display ~/datasets/custom
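Before converting a custom dataset, it can help to verify the shape requirement stated above: every image must be square, with the same power-of-two dimensions. The sketch below checks a list of (width, height) pairs. check_dimensions is a hypothetical helper, and collecting the sizes from the image files (for example with an image library) is left out.

```python
def check_dimensions(sizes):
    """Return the common side length, or raise ValueError if the sizes are unsuitable.

    sizes is a list of (width, height) pairs, one per image.
    """
    if not sizes:
        raise ValueError("no images found")
    width, height = sizes[0]
    if width != height:
        raise ValueError(f"image 0 is not square: {width}x{height}")
    # A positive integer is a power of two exactly when it has a single set bit.
    if width < 1 or (width & (width - 1)) != 0:
        raise ValueError(f"side length {width} is not a power of two")
    for index, size in enumerate(sizes):
        if size != (width, height):
            raise ValueError(f"image {index} is {size[0]}x{size[1]}, expected {width}x{height}")
    return width


if __name__ == "__main__":
    print(check_dimensions([(512, 512), (512, 512)]))
```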
In its most basic form, training new networks boils down to:
python train.py --outdir=~/training-runs --gpus=1 --data=~/datasets/custom --dry-run
python train.py --outdir=~/training-runs --gpus=1 --data=~/datasets/custom
The first command is optional; it will validate the arguments, print out the resulting training configuration, and exit. The second command will kick off the actual training.
In this example, the results will be saved to a newly created directory ~/training-runs/<RUNNING_ID>-custom-auto1 (controlled by --outdir). The training will export network pickles (network-snapshot-<KIMG>.pkl) and example images (fakes<KIMG>.png) at regular intervals (controlled by --snap). For each pickle, it will also evaluate FID by default (controlled by --metrics) and log the resulting scores in metric-fid50k_full.txt.
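Because snapshots are exported as network-snapshot-<KIMG>.pkl, the most recent one in a run directory can be located by comparing the numeric part of the filename. The sketch below uses a hypothetical helper, latest_snapshot, and assumes <KIMG> is written as a plain decimal integer.

```python
import glob
import os
import re


def latest_snapshot(run_dir):
    """Return (kimg, path) for the snapshot with the largest KIMG, or None if none exist."""
    best = None
    for path in glob.glob(os.path.join(run_dir, "network-snapshot-*.pkl")):
        match = re.fullmatch(r"network-snapshot-(\d+)\.pkl", os.path.basename(path))
        if not match:
            continue
        # Assumption: the KIMG field is a (possibly zero-padded) decimal integer.
        kimg = int(match.group(1))
        if best is None or kimg > best[0]:
            best = (kimg, path)
    return best


if __name__ == "__main__":
    run_dir = os.path.expanduser("~/training-runs/00000-custom-auto1")
    print(latest_snapshot(run_dir))
```

The returned path could then be passed to generate.py via --network, since pickles can be referenced by local filename.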
The name of the output directory (e.g., 00000-custom-auto1) reflects the hyperparameter configuration that was used. In this case, custom indicates the training set (--data) and auto1 indicates the base configuration that was used to select the hyperparameters (--cfg):
| Base config | Description |
|---|---|
| auto (default) | Automatically select reasonable defaults based on resolution and GPU count. Serves as a good starting point for new datasets, but does not necessarily lead to optimal results. |
| stylegan2 | Reproduce results for StyleGAN2 config F at 1024x1024 using 1, 2, 4, or 8 GPUs. |
| paper256 | Reproduce results for FFHQ |