MCPcopy
hub / github.com/Vchitect/Latte / main

Function main

sample/sample_t2x.py:24–163  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

22from torchvision.utils import save_image
23
24def main(args):
25 # torch.manual_seed(args.seed)
26 torch.set_grad_enabled(False)
27 device = "cuda" if torch.cuda.is_available() else "cpu"
28
29 transformer_model = get_models(args).to(device, dtype=torch.float16)
30
31 if args.enable_vae_temporal_decoder:
32 vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
33 else:
34 vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device)
35 tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
36 text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device)
37
38 # set eval mode
39 transformer_model.eval()
40 vae.eval()
41 text_encoder.eval()
42
43 if args.sample_method == 'DDIM':
44 scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path,
45 subfolder="scheduler",
46 beta_start=args.beta_start,
47 beta_end=args.beta_end,
48 beta_schedule=args.beta_schedule,
49 variance_type=args.variance_type,
50 clip_sample=False)
51 elif args.sample_method == 'EulerDiscrete':
52 scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path,
53 subfolder="scheduler",
54 beta_start=args.beta_start,
55 beta_end=args.beta_end,
56 beta_schedule=args.beta_schedule,
57 variance_type=args.variance_type)
58 elif args.sample_method == 'DDPM':
59 scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path,
60 subfolder="scheduler",
61 beta_start=args.beta_start,
62 beta_end=args.beta_end,
63 beta_schedule=args.beta_schedule,
64 variance_type=args.variance_type,
65 clip_sample=False)
66 elif args.sample_method == 'DPMSolverMultistep':
67 scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path,
68 subfolder="scheduler",
69 beta_start=args.beta_start,
70 beta_end=args.beta_end,
71 beta_schedule=args.beta_schedule,
72 variance_type=args.variance_type)
73 elif args.sample_method == 'DPMSolverSinglestep':
74 scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path,
75 subfolder="scheduler",
76 beta_start=args.beta_start,
77 beta_end=args.beta_end,
78 beta_schedule=args.beta_schedule,
79 variance_type=args.variance_type)
80 elif args.sample_method == 'PNDM':
81 scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path,

Callers 1

sample_t2x.py · File · 0.70

Calls 2

get_models · Function · 0.90
LattePipeline · Class · 0.90

Tested by

no test coverage detected