MCPcopy
hub / github.com/Vchitect/Latte / main

Function main

sample/sample_ddp.py:51–185  ·  view source on GitHub ↗

Run sampling.

(args)

Source from the content-addressed store, hash-verified

49
50
51def main(args):
52 """
53 Run sampling.
54 """
55 torch.backends.cuda.matmul.allow_tf32 = True # True: fast but may lead to some small numerical differences
56 assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage"
57 torch.set_grad_enabled(False)
58
59 # Setup DDP:
60 dist.init_process_group("nccl")
61 rank = dist.get_rank()
62 device = rank % torch.cuda.device_count()
63 if args.seed:
64 seed = args.seed * dist.get_world_size() + rank
65 torch.manual_seed(seed)
66 torch.cuda.set_device(device)
67 # print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
68
69 if args.ckpt is None:
70 assert args.model == "Latte-XL/2", "Only Latte-XL/2 models are available for auto-download."
71 assert args.image_size in [256, 512]
72 assert args.num_classes == 1000
73
74 # Load model:
75 latent_size = args.image_size // 8
76 args.latent_size = latent_size
77 model = get_models(args).to(device)
78
79 if args.use_compile:
80 model = torch.compile(model)
81
82 # a pre-trained model or load a custom Latte checkpoint from train.py:
83 ckpt_path = args.ckpt
84 state_dict = find_model(ckpt_path)
85 model.load_state_dict(state_dict)
86 model.eval() # important!
87 diffusion = create_diffusion(str(args.num_sampling_steps))
88 # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
89 # vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
90 vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="sd-vae-ft-ema").to(device)
91
92 if args.use_fp16:
93 print('WARNING: using half percision for inferencing!')
94 vae.to(dtype=torch.float16)
95 model.to(dtype=torch.float16)
96 # text_encoder.to(dtype=torch.float16)
97
98 assert args.cfg_scale >= 1.0, "In almost all cases, cfg_scale be >= 1.0"
99 using_cfg = args.cfg_scale > 1.0
100
101 # Create folder to save samples:
102 # model_string_name = args.model.replace("/", "-")
103 # ckpt_string_name = os.path.basename(args.ckpt).replace(".pt", "") if args.ckpt else "pretrained"
104 # folder_name = f"{model_string_name}-{ckpt_string_name}-size-{args.image_size}-vae-{args.vae}-" \
105 # f"cfg-{args.cfg_scale}-seed-{args.seed}"
106 # sample_folder_dir = f"{args.sample_dir}/{folder_name}"
107 sample_folder_dir = args.save_video_path
108 if args.seed:

Callers 1

sample_ddp.py · File · 0.70

Calls 5

get_models · Function · 0.90
find_model · Function · 0.90
create_diffusion · Function · 0.90
ddim_sample_loop · Method · 0.80
p_sample_loop · Method · 0.80

Tested by

no test coverage detected