(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: float = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
device="cuda",
)
| 241 | |
| 242 | |
| 243 | def do_img2img( |
| 244 | img, |
| 245 | model, |
| 246 | sampler, |
| 247 | value_dict, |
| 248 | num_samples, |
| 249 | force_uc_zero_embeddings=[], |
| 250 | additional_kwargs={}, |
| 251 | offset_noise_level: float = 0.0, |
| 252 | return_latents=False, |
| 253 | skip_encode=False, |
| 254 | filter=None, |
| 255 | device="cuda", |
| 256 | ): |
| 257 | with torch.no_grad(): |
| 258 | with autocast(device) as precision_scope: |
| 259 | with model.ema_scope(): |
| 260 | batch, batch_uc = get_batch( |
| 261 | get_unique_embedder_keys_from_conditioner(model.conditioner), |
| 262 | value_dict, |
| 263 | [num_samples], |
| 264 | ) |
| 265 | c, uc = model.conditioner.get_unconditional_conditioning( |
| 266 | batch, |
| 267 | batch_uc=batch_uc, |
| 268 | force_uc_zero_embeddings=force_uc_zero_embeddings, |
| 269 | ) |
| 270 | |
| 271 | for k in c: |
| 272 | c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) |
| 273 | |
| 274 | for k in additional_kwargs: |
| 275 | c[k] = uc[k] = additional_kwargs[k] |
| 276 | if skip_encode: |
| 277 | z = img |
| 278 | else: |
| 279 | z = model.encode_first_stage(img) |
| 280 | noise = torch.randn_like(z) |
| 281 | sigmas = sampler.discretization(sampler.num_steps) |
| 282 | sigma = sigmas[0].to(z.device) |
| 283 | |
| 284 | if offset_noise_level > 0.0: |
| 285 | noise = noise + offset_noise_level * append_dims( |
| 286 | torch.randn(z.shape[0], device=z.device), z.ndim |
| 287 | ) |
| 288 | noised_z = z + noise * append_dims(sigma, z.ndim) |
| 289 | noised_z = noised_z / torch.sqrt( |
| 290 | 1.0 + sigmas[0] ** 2.0 |
| 291 | ) # Note: hardcoded to DDPM-like scaling. need to generalize later. |
| 292 | |
| 293 | def denoiser(x, sigma, c): |
| 294 | return model.denoiser(model.model, x, sigma, c) |
| 295 | |
| 296 | samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) |
| 297 | samples_x = model.decode_first_stage(samples_z) |
| 298 | samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) |
| 299 | |
| 300 | if filter is not None: |
no test coverage detected