Run sampling.
(args)
| 49 | |
| 50 | |
| 51 | def main(args): |
| 52 | """ |
| 53 | Run sampling. |
| 54 | """ |
| 55 | torch.backends.cuda.matmul.allow_tf32 = True # True: fast but may lead to some small numerical differences |
| 56 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" |
| 57 | torch.set_grad_enabled(False) |
| 58 | |
| 59 | # Setup DDP: |
| 60 | dist.init_process_group("nccl") |
| 61 | rank = dist.get_rank() |
| 62 | device = rank % torch.cuda.device_count() |
| 63 | if args.seed: |
| 64 | seed = args.seed * dist.get_world_size() + rank |
| 65 | torch.manual_seed(seed) |
| 66 | torch.cuda.set_device(device) |
| 67 | # print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") |
| 68 | |
| 69 | if args.ckpt is None: |
| 70 | assert args.model == "Latte-XL/2", "Only Latte-XL/2 models are available for auto-download." |
| 71 | assert args.image_size in [256, 512] |
| 72 | assert args.num_classes == 1000 |
| 73 | |
| 74 | # Load model: |
| 75 | latent_size = args.image_size // 8 |
| 76 | args.latent_size = latent_size |
| 77 | model = get_models(args).to(device) |
| 78 | |
| 79 | if args.use_compile: |
| 80 | model = torch.compile(model) |
| 81 | |
| 82 | # a pre-trained model or load a custom Latte checkpoint from train.py: |
| 83 | ckpt_path = args.ckpt |
| 84 | state_dict = find_model(ckpt_path) |
| 85 | model.load_state_dict(state_dict) |
| 86 | model.eval() # important! |
| 87 | diffusion = create_diffusion(str(args.num_sampling_steps)) |
| 88 | # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device) |
| 89 | # vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device) |
| 90 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device) |
| 91 | |
| 92 | if args.use_fp16: |
| 93 | print('WARNING: using half percision for inferencing!') |
| 94 | vae.to(dtype=torch.float16) |
| 95 | model.to(dtype=torch.float16) |
| 96 | # text_encoder.to(dtype=torch.float16) |
| 97 | |
| 98 | assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0" |
| 99 | using_cfg = args.cfg_scale > 1.0 |
| 100 | |
| 101 | # Create folder to save samples: |
| 102 | # model_string_name = args.model.replace("/", "-") |
| 103 | # ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained" |
| 104 | # folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \ |
| 105 | # f"cfg-{args.cfg_scale}-seed-{args.seed}" |
| 106 | # sample_folder_dir = f"{args.sample_dir}/{folder_name}" |
| 107 | sample_folder_dir = args.save_video_path |
| 108 | if args.seed: |
no test coverage detected