MCPcopy
hub / github.com/Stability-AI/generative-models / do_img2img

Function do_img2img

sgm/inference/helpers.py:243–305  ·  view source on GitHub ↗
(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings=[],
    additional_kwargs={},
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
)

Source from the content-addressed store, hash-verified

241
242
243def do_img2img(
244 img,
245 model,
246 sampler,
247 value_dict,
248 num_samples,
249 force_uc_zero_embeddings=[],
250 additional_kwargs={},
251 offset_noise_level: float = 0.0,
252 return_latents=False,
253 skip_encode=False,
254 filter=None,
255 device="cuda",
256):
257 with torch.no_grad():
258 with autocast(device) as precision_scope:
259 with model.ema_scope():
260 batch, batch_uc = get_batch(
261 get_unique_embedder_keys_from_conditioner(model.conditioner),
262 value_dict,
263 [num_samples],
264 )
265 c, uc = model.conditioner.get_unconditional_conditioning(
266 batch,
267 batch_uc=batch_uc,
268 force_uc_zero_embeddings=force_uc_zero_embeddings,
269 )
270
271 for k in c:
272 c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc))
273
274 for k in additional_kwargs:
275 c[k] = uc[k] = additional_kwargs[k]
276 if skip_encode:
277 z = img
278 else:
279 z = model.encode_first_stage(img)
280 noise = torch.randn_like(z)
281 sigmas = sampler.discretization(sampler.num_steps)
282 sigma = sigmas[0].to(z.device)
283
284 if offset_noise_level > 0.0:
285 noise = noise + offset_noise_level * append_dims(
286 torch.randn(z.shape[0], device=z.device), z.ndim
287 )
288 noised_z = z + noise * append_dims(sigma, z.ndim)
289 noised_z = noised_z / torch.sqrt(
290 1.0 + sigmas[0] ** 2.0
291 ) # Note: hardcoded to DDPM-like scaling. need to generalize later.
292
293 def denoiser(x, sigma, c):
294 return model.denoiser(model.model, x, sigma, c)
295
296 samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
297 samples_x = model.decode_first_stage(samples_z)
298 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
299
300 if filter is not None:

Callers 2

image_to_image (Method) · 0.90
refiner (Method) · 0.90

Calls 8

append_dims (Function) · 0.90
autocast (Function) · 0.85
encode_first_stage (Method) · 0.80
decode_first_stage (Method) · 0.80
get_batch (Function) · 0.70
ema_scope (Method) · 0.45

Tested by

no test coverage detected