(args)
| 22 | from torchvision.utils import save_image |
| 23 | |
| 24 | def main(args): |
| 25 | # torch.manual_seed(args.seed) |
| 26 | torch.set_grad_enabled(False) |
| 27 | device = "cuda" if torch.cuda.is_available() else "cpu" |
| 28 | |
| 29 | transformer_model = get_models(args).to(device, dtype=torch.float16) |
| 30 | |
| 31 | if args.enable_vae_temporal_decoder: |
| 32 | vae = AutoencoderKLTemporalDecoder.from_pretrained(args.pretrained_model_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device) |
| 33 | else: |
| 34 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae", torch_dtype=torch.float16).to(device) |
| 35 | tokenizer = T5Tokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") |
| 36 | text_encoder = T5EncoderModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) |
| 37 | |
| 38 | # set eval mode |
| 39 | transformer_model.eval() |
| 40 | vae.eval() |
| 41 | text_encoder.eval() |
| 42 | |
| 43 | if args.sample_method == 'DDIM': |
| 44 | scheduler = DDIMScheduler.from_pretrained(args.pretrained_model_path, |
| 45 | subfolder="scheduler", |
| 46 | beta_start=args.beta_start, |
| 47 | beta_end=args.beta_end, |
| 48 | beta_schedule=args.beta_schedule, |
| 49 | variance_type=args.variance_type, |
| 50 | clip_sample=False) |
| 51 | elif args.sample_method == 'EulerDiscrete': |
| 52 | scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_path, |
| 53 | subfolder="scheduler", |
| 54 | beta_start=args.beta_start, |
| 55 | beta_end=args.beta_end, |
| 56 | beta_schedule=args.beta_schedule, |
| 57 | variance_type=args.variance_type) |
| 58 | elif args.sample_method == 'DDPM': |
| 59 | scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_path, |
| 60 | subfolder="scheduler", |
| 61 | beta_start=args.beta_start, |
| 62 | beta_end=args.beta_end, |
| 63 | beta_schedule=args.beta_schedule, |
| 64 | variance_type=args.variance_type, |
| 65 | clip_sample=False) |
| 66 | elif args.sample_method == 'DPMSolverMultistep': |
| 67 | scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_path, |
| 68 | subfolder="scheduler", |
| 69 | beta_start=args.beta_start, |
| 70 | beta_end=args.beta_end, |
| 71 | beta_schedule=args.beta_schedule, |
| 72 | variance_type=args.variance_type) |
| 73 | elif args.sample_method == 'DPMSolverSinglestep': |
| 74 | scheduler = DPMSolverSinglestepScheduler.from_pretrained(args.pretrained_model_path, |
| 75 | subfolder="scheduler", |
| 76 | beta_start=args.beta_start, |
| 77 | beta_end=args.beta_end, |
| 78 | beta_schedule=args.beta_schedule, |
| 79 | variance_type=args.variance_type) |
| 80 | elif args.sample_method == 'PNDM': |
| 81 | scheduler = PNDMScheduler.from_pretrained(args.pretrained_model_path, |
no test coverage detected