(
self,
width,
height,
num_steps,
guidance,
seed,
prompt,
init_image=None,
image2image_strength=0.0,
add_sampling_metadata=True,
)
| 49 | |
| 50 | @torch.inference_mode() |
| 51 | def generate_image( |
| 52 | self, |
| 53 | width, |
| 54 | height, |
| 55 | num_steps, |
| 56 | guidance, |
| 57 | seed, |
| 58 | prompt, |
| 59 | init_image=None, |
| 60 | image2image_strength=0.0, |
| 61 | add_sampling_metadata=True, |
| 62 | ): |
| 63 | seed = int(seed) |
| 64 | if seed == -1: |
| 65 | seed = None |
| 66 | |
| 67 | opts = SamplingOptions( |
| 68 | prompt=prompt, |
| 69 | width=width, |
| 70 | height=height, |
| 71 | num_steps=num_steps, |
| 72 | guidance=guidance, |
| 73 | seed=seed, |
| 74 | ) |
| 75 | |
| 76 | if opts.seed is None: |
| 77 | opts.seed = torch.Generator(device="cpu").seed() |
| 78 | print(f"Generating '{opts.prompt}' with seed {opts.seed}") |
| 79 | t0 = time.perf_counter() |
| 80 | |
| 81 | if init_image is not None: |
| 82 | if isinstance(init_image, np.ndarray): |
| 83 | init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0 |
| 84 | init_image = init_image.unsqueeze(0) |
| 85 | init_image = init_image.to(self.device) |
| 86 | init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) |
| 87 | if self.offload: |
| 88 | self.ae.encoder.to(self.device) |
| 89 | init_image = self.ae.encode(init_image.to()) |
| 90 | if self.offload: |
| 91 | self.ae = self.ae.cpu() |
| 92 | torch.cuda.empty_cache() |
| 93 | |
| 94 | # prepare input |
| 95 | x = get_noise( |
| 96 | 1, |
| 97 | opts.height, |
| 98 | opts.width, |
| 99 | device=self.device, |
| 100 | dtype=torch.bfloat16, |
| 101 | seed=opts.seed, |
| 102 | ) |
| 103 | timesteps = get_schedule( |
| 104 | opts.num_steps, |
| 105 | x.shape[-1] * x.shape[-2] // 4, |
| 106 | shift=(not self.is_schnell), |
| 107 | ) |
| 108 | if init_image is not None: |
nothing calls this directly
no test coverage detected