github.com/PixArt-alpha/PixArt-sigma @main sqlite

608 symbols · 2,354 edges · 85 files · 192 documented (32%)
README

👉 PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation


This repo contains PyTorch model definitions, pre-trained weights and inference/sampling code for our paper exploring Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation. You can find more visualizations on our project page.

PixArt-Σ: Weak-to-Strong Training of Diffusion Transformer for 4K Text-to-Image Generation

Junsong Chen*, Chongjian Ge*, Enze Xie*†, Yue Wu*, Lewei Yao, Xiaozhe Ren, Zhongdao Wang, Ping Luo, Huchuan Lu, Zhenguo Li

Huawei Noah’s Ark Lab, DLUT, HKU, HKUST


Contributions from everyone are welcome 🔥🔥!

Drawing on lessons from the previous PixArt-α project, we aim to keep this repo as simple as possible so that everyone in the PixArt community can use it.


Breaking News 🔥🔥!!


Contents

- Main
  - Weak-to-Strong
  - Training
  - Inference
  - Use diffusers
  - Launch Demo
  - Available Models
- Guidance
  - Feature extraction* (Optional)
  - One step Generation (DMD)
  - LoRA & DoRA
  - [LCM: coming soon]
  - [ControlNet: coming soon]
  - [ComfyUI: coming soon]
  - Data reformat* (Optional)
- Others
  - Acknowledgement
  - Citation
  - TODO


🆚 Compare with PixArt-α

| Model | T5 token length | VAE | 2K/4K |
|---|---|---|---|
| PixArt-Σ | 300 | SDXL | ✅ |
| PixArt-α | 120 | SD1.5 | ❌ |
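Native 2K/4K support is demanding because the transformer's token count grows with the square of the image side. The back-of-the-envelope sketch below assumes an 8× VAE downsampling factor (true of both the SD1.5 and SDXL VAEs) and a patch size of 2, read from the "XL-2" in the checkpoint names by analogy with DiT naming; that patch size is an assumption, not taken from the repo's code.

```python
def latent_tokens(height: int, width: int, vae_factor: int = 8, patch_size: int = 2) -> int:
    """Number of patch tokens a DiT-style transformer processes for one image.

    Assumes the VAE shrinks each side by `vae_factor` and the transformer then
    splits the latent into non-overlapping `patch_size` x `patch_size` patches.
    """
    step = vae_factor * patch_size
    if height % step or width % step:
        raise ValueError(f"height and width must be multiples of {step}")
    return (height // step) * (width // step)


if __name__ == "__main__":
    for side in (512, 1024, 2048, 4096):
        print(f"{side}x{side}: {latent_tokens(side, side)} tokens")
```

Under these assumptions a 4096×4096 image yields 16× the tokens of a 1024×1024 image, and since self-attention cost grows with the square of the token count, roughly 256× the attention cost.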
[Figure: PixArt-Σ vs. PixArt-α samples for the three prompts below]

| Sample | Prompt |
|---|---|
| Sample-1 | Close-up, gray-haired, bearded man in 60s, observing passersby, in wool coat and brown beret, glasses, cinematic. |
| Sample-2 | Body shot, a French woman, Photography, French Streets background, backlight, rim light, Fujifilm. |
| Sample-3 | Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee. |

Prompt Details

Sample-1 full prompt: An extreme close-up of an gray-haired man with a beard in his 60s, he is deep in thought pondering the history of the universe as he sits at a cafe in Paris, his eyes focus on people offscreen as they walk as he sits mostly motionless, he is dressed in a wool coat suit coat with a button-down shirt , he wears a brown beret and glasses and has a very professorial appearance, and the end he offers a subtle closed-mouth smile as if he found the answer to the mystery of life, the lighting is very cinematic with the golden light and the Parisian streets and city in the background, depth of field, cinematic 35mm film.

🔧 Dependencies and Installation

conda create -n pixart python==3.9.0
conda activate pixart
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

git clone https://github.com/PixArt-alpha/PixArt-sigma.git
cd PixArt-sigma
pip install -r requirements.txt
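After installing, a quick check can catch version drift in pinned packages. This is a minimal standard-library sketch (the helper names are hypothetical); it assumes requirements.txt uses plain `name==version` lines, like the versions listed in the repo's manifests (for example accelerate 0.25.0, timm 0.6.12 and transformers 4.36.1), and skips anything else, such as git URLs.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_pins(lines):
    """Collect {name: version} from 'name==version' lines.

    Comments, environment markers, extras and unpinned entries (e.g. git URLs)
    are ignored, so this is deliberately simpler than a full requirements parser.
    """
    pins = {}
    for raw in lines:
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if "==" not in line:
            continue
        name, _, ver = line.partition("==")
        pins[name.split("[", 1)[0].strip()] = ver.strip()
    return pins


def check_pins(pins, lookup=version):
    """Return {package: (wanted, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = lookup(name)
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    # Example pins from the dependency list; use parse_pins(open("requirements.txt")) for the real file.
    example = parse_pins(["accelerate==0.25.0", "timm==0.6.12", "transformers==4.36.1"])
    for name, (wanted, got) in check_pins(example).items():
        print(f"{name}: wanted {wanted}, found {got or 'not installed'}")
```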

🔥 How to Train

1. PixArt Training

First of all: we started this new repo to build a more user-friendly and more compatible codebase. The main model structure is the same as PixArt-α, so you can still develop your features on top of the original repo. This repo will also support PixArt-α in the future.

[!TIP]
You can now train your model without extracting features first. We reformatted the data structure of the PixArt-α codebase so that everyone can start training, running inference, and visualizing results right away.

1.1 Downloading the toy dataset

Download the toy dataset first. The dataset structure for training is:

cd ./pixart-sigma-toy-dataset

Dataset Structure
├──InternImgs/  (images are saved here)
│  ├──000000000000.png
│  ├──000000000001.png
│  ├──......
├──InternData/
│  ├──data_info.json    (meta data)
Optional(👇)
│  ├──img_sdxl_vae_features_1024resolution_ms_new    (image VAE features, same name as images except .npy extension)
│  │  ├──000000000000.npy
│  │  ├──000000000001.npy
│  │  ├──......
│  ├──caption_features_new
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
│  ├──sharegpt4v_caption_features_new    (run tools/extract_caption_feature.py to generate caption T5 features, same name as images except .npz extension)
│  │  ├──000000000000.npz
│  │  ├──000000000001.npz
│  │  ├──......
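Because the optional features are matched to images purely by file name, a small check can list which images still lack a feature file, for example before training with precomputed features. A minimal standard-library sketch; the function name `missing_features` is hypothetical, and it assumes images are `.png` files as in the tree.

```python
from pathlib import Path


def missing_features(root, feature_dir, suffix):
    """Return sorted stems of InternImgs/*.png that have no matching feature file.

    Uses only the naming rule from the dataset tree: a feature file has the same
    name as its image, with a different extension (.npy or .npz).
    """
    root = Path(root)
    image_stems = {p.stem for p in (root / "InternImgs").glob("*.png")}
    feat_path = root / "InternData" / feature_dir
    feature_stems = {p.stem for p in feat_path.glob(f"*{suffix}")} if feat_path.is_dir() else set()
    return sorted(image_stems - feature_stems)


if __name__ == "__main__":
    root = Path("./pixart-sigma-toy-dataset")  # location assumed from the cd step above
    if root.is_dir():
        for feature_dir, suffix in [
            ("img_sdxl_vae_features_1024resolution_ms_new", ".npy"),
            ("caption_features_new", ".npz"),
            ("sharegpt4v_caption_features_new", ".npz"),
        ]:
            print(feature_dir, len(missing_features(root, feature_dir, suffix)), "images without features")
```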

1.2 Download pretrained checkpoint

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py  # to use a Hugging Face mirror, set an environment variable such as HF_ENDPOINT=https://hf-mirror.com
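When the download is driven from Python rather than a shell, the mirror variable has to be in the child process's environment before it starts. A minimal sketch with hypothetical helper names; it assumes, as the comment above suggests, that tools/download.py honors HF_ENDPOINT (for instance through huggingface_hub, which reads the variable once at import time).

```python
import os
import subprocess
import sys


def download_env(endpoint=None):
    """Return a copy of the current environment, optionally with HF_ENDPOINT set."""
    env = dict(os.environ)  # copy, so the parent process environment is untouched
    if endpoint:
        env["HF_ENDPOINT"] = endpoint
    return env


def run_download(endpoint=None):
    """Run tools/download.py in a child process (call from the repo root)."""
    # Set the variable before the child starts: libraries such as huggingface_hub
    # read HF_ENDPOINT when they are imported, not on every request.
    return subprocess.run([sys.executable, "tools/download.py"], env=download_env(endpoint), check=True)
```

For example, `run_download("https://hf-mirror.com")` from the repo root is equivalent to prefixing the shell command with `HF_ENDPOINT=https://hf-mirror.com`.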

1.3 You are ready to train!

Select the config file you want from the config files directory.

python -m torch.distributed.launch --nproc_per_node=1 --master_port=12345 \
          train_scripts/train.py \
          configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py \
          --load-from output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth \
          --work-dir output/your_first_pixart-exp \
          --debug
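If you script experiments, the same launch can be assembled in Python, for example to raise `--nproc_per_node` on a multi-GPU machine. This sketch uses only the flags shown in the command above; the helper name `train_command` is hypothetical.

```python
import shlex
import sys


def train_command(config, load_from, work_dir, nproc=1, port=12345, debug=False):
    """Build the distributed training launch as an argv list (no shell needed)."""
    cmd = [
        sys.executable, "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc}",  # number of processes (GPUs) on this machine
        f"--master_port={port}",
        "train_scripts/train.py", config,
        "--load-from", load_from,
        "--work-dir", work_dir,
    ]
    if debug:
        cmd.append("--debug")
    return cmd


if __name__ == "__main__":
    cmd = train_command(
        "configs/pixart_sigma_config/PixArt_sigma_xl2_img512_internalms.py",
        "output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth",
        "output/your_first_pixart-exp",
        debug=True,
    )
    print(shlex.join(cmd))  # pass cmd to subprocess.run(..., check=True) to actually launch
```

Recent PyTorch releases mark `torch.distributed.launch` as deprecated in favor of `torchrun`; the sketch keeps the launcher used in this README unchanged.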

💻 How to Test

1. Quick start with Gradio

To get started, first install the required dependencies. Make sure you've downloaded the checkpoint files from models (coming soon) to the output/pretrained_models folder, and then run on your local machine:

# SDXL-VAE, T5 checkpoints
git lfs install
git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers output/pretrained_models/pixart_sigma_sdxlvae_T5_diffusers

# PixArt-Sigma checkpoints
python tools/download.py

# demo launch
python scripts/interface.py --model_path output/pretrained_models/PixArt-Sigma-XL-2-512-MS.pth --image_size 512 --port 11223

2. Integration in diffusers

[!IMPORTANT]
Upgrade diffusers to make PixArtSigmaPipeline available:
pip install git+https://github.com/huggingface/diffusers

For diffusers<0.28.0, check this script for help.

import torch
from diffusers import Transformer2DModel, PixArtSigmaPipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
    subfolder='transformer', 
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe.to(device)

# Optional memory optimization: offload idle sub-models to the CPU.
# If you enable it, call it instead of pipe.to(device) above.
# pipe.enable_model_cpu_offload()

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
image.save("./cactus.png")

3. PixArt Demo

pip install git+https://github.com/huggingface/diffusers

# PixArt-Sigma 1024px
DEMO_PORT=12345 python app/app_pixart_sigma.py

# PixArt-Sigma One step Sampler(DMD)
DEMO_PORT=12345 python app/app_pixart_dmd.py

Then open http://your-server-ip:12345 in a browser to try a simple example.

4. Convert .pth checkpoint into diffusers version

Download the converted weights directly from Hugging Face, or convert a .pth checkpoint yourself:

pip install git+https://github.com/huggingface/diffusers

python tools/convert_pixart_to_diffusers.py --orig_ckpt_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS.pth --dump_path output/pretrained_models/PixArt-Sigma-XL-2-1024-MS --only_transformer=True --image_size=1024 --version sigma

⏬ Available Models

All models will be downloaded automatically here. You can also download them manually from this URL.

| Model | #Params | Checkpoint path

Core symbols (most depended-on inside this repo)

| Symbol | Called by | File |
|---|---|---|
| to | 246 | train_scripts/train_pixart_lcm.py |
| log | 49 | diffusion/utils/misc.py |
| randn | 28 | diffusion/model/utils.py |
| get_root_logger | 23 | diffusion/utils/logger.py |
| _extract_into_tensor | 23 | diffusion/model/gaussian_diffusion.py |
| marginal_lambda | 22 | diffusion/model/sa_solver.py |
| marginal_std | 19 | diffusion/model/dpm_solver.py |
| marginal_lambda | 19 | diffusion/model/dpm_solver.py |
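Several of the most-called helpers belong to the DPM-Solver and SA-Solver samplers. As a hedged illustration of what `marginal_std` and `marginal_lambda` compute (the noise level sigma_t and the half log-SNR lambda_t = log alpha_t - log sigma_t), here is a standalone sketch assuming the continuous linear VP schedule from the DPM-Solver paper (beta_0 = 0.1, beta_1 = 20). In the repo these are likely methods on a noise-schedule object and may use discrete betas instead.

```python
import math

# Assumed continuous-time linear VP schedule (DPM-Solver paper defaults).
BETA_0, BETA_1 = 0.1, 20.0


def marginal_log_mean_coeff(t: float) -> float:
    """log(alpha_t) for the linear VP schedule, with t in [0, 1]."""
    return -0.25 * t**2 * (BETA_1 - BETA_0) - 0.5 * t * BETA_0


def marginal_std(t: float) -> float:
    """sigma_t = sqrt(1 - alpha_t^2), the std of the noise in x_t."""
    return math.sqrt(1.0 - math.exp(2.0 * marginal_log_mean_coeff(t)))


def marginal_lambda(t: float) -> float:
    """lambda_t = log(alpha_t) - log(sigma_t); requires t > 0 (sigma_0 = 0)."""
    log_mean = marginal_log_mean_coeff(t)
    log_std = 0.5 * math.log(1.0 - math.exp(2.0 * log_mean))
    return log_mean - log_std


if __name__ == "__main__":
    for t in (0.1, 0.5, 1.0):
        print(f"t={t}: sigma_t={marginal_std(t):.4f}, lambda_t={marginal_lambda(t):.3f}")
```

lambda_t decreases monotonically in t, which is why DPM-Solver-style samplers can step uniformly in lambda rather than in t.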

Shape

Method 316
Function 221
Class 71

Languages

Python 100%

Modules by API surface

diffusion/model/nets/PixArt_blocks.py · 44 symbols
diffusion/model/utils.py · 38 symbols
diffusion/model/sa_solver.py · 35 symbols
diffusion/model/dpm_solver.py · 35 symbols
diffusion/model/gaussian_diffusion.py · 33 symbols
diffusion/model/llava/mpt/modeling_mpt.py · 26 symbols
diffusion/utils/misc.py · 23 symbols
diffusion/utils/dist_utils.py · 23 symbols
diffusion/data/datasets/InternalData.py · 23 symbols
diffusion/sa_solver_diffusers.py · 19 symbols
diffusion/model/nets/PixArt.py · 16 symbols
diffusion/model/timestep_sampler.py · 15 symbols

Dependencies from manifests, versioned

accelerate 0.25.0 · 1×
gradio 4.1.1 · 1×
mmcv 1.7.0 · 1×
protobuf 3.20.2 · 1×
sentencepiece 0.1.99 · 1×
timm 0.6.12 · 1×
transformers 4.36.1 · 1×
xformers 0.0.19 · 1×
yapf 0.40.1 · 1×

For agents

$ claude mcp add PixArt-sigma \
  -- python -m otcore.mcp_server <graph>
