Repository: github.com/Stability-AI/generative-models

Function do_sample

scripts/demo/streamlit_helpers.py:512–597
(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: List = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
)
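The signature gives `force_uc_zero_embeddings` and `batch2model_input` a default of `None`. The first statements of the body then replace `None` with a fresh empty list, instead of writing `= []` in the signature. The self-contained sketch below shows why this matters. The names `append_shared` and `append_fresh` are hypothetical and do not come from the repository.

```python
from typing import List, Optional


def append_shared(item: str, acc: List[str] = []) -> List[str]:
    # Anti-pattern, on purpose: the default list is created once, when the
    # function is defined, so every call that omits `acc` mutates the same object.
    acc.append(item)
    return acc


def append_fresh(item: str, acc: Optional[List[str]] = None) -> List[str]:
    # Same normalization as do_sample: None means "start from a new list".
    if acc is None:
        acc = []
    acc.append(item)
    return acc


if __name__ == "__main__":
    print(append_shared("a"), append_shared("b"))  # ['a', 'b'] ['a', 'b']
    print(append_fresh("a"), append_fresh("b"))    # ['a'] ['b']
```

In `append_shared`, both calls return the same list object. In `append_fresh`, each call gets its own list. This is the behavior `do_sample` relies on when the caller omits the two list arguments.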

Source

```python
# Excerpt of scripts/demo/streamlit_helpers.py, starting at line 512.
# math, torch, st (Streamlit), autocast, List and the helpers load_model,
# unload_model, get_batch and get_unique_embedder_keys_from_conditioner
# are defined or imported at module level in that file.
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: List = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
):
    # Replace None defaults with fresh lists (avoids shared mutable defaults).
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    st.text("Sampling")

    outputs = st.empty()
    precision_scope = autocast
    # No gradients, CUDA mixed precision, and the model's EMA scope for the whole run.
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                num_samples = [num_samples]
                # Build conditional and unconditional batches with the conditioner loaded.
                load_model(model.conditioner)
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                )
                # Print the shape (tensors), per-item lengths (lists) or value of each entry.
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(l) for l in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )
                unload_model(model.conditioner)

                # Trim every non-"crossattn" conditioning tensor to the batch size
                # and move it to the GPU.
                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = map(
                            lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                        )

                # Forward the selected batch entries to the model unchanged.
                additional_model_inputs = {}
                for k in batch2model_input:
                    additional_model_inputs[k] = batch[k]

                # Latent shape: (batch size, latent channels, H // F, W // F).
                shape = (math.prod(num_samples), C, H // F, W // F)
                # ... (listing ends here; the function continues through line 597)
```
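The tail of the excerpt does three things: it trims the conditioning to the batch size, gathers extra model inputs, and computes the latent shape. Below is a minimal sketch of those steps without torch. It assumes plain Python lists in place of tensors, so the `.to("cuda")` transfer is left out. The names `trim_conditioning`, `collect_model_inputs` and `latent_shape` are hypothetical, not functions from the repository.

```python
import math
from typing import Dict, List, Sequence, Tuple


def trim_conditioning(
    c: Dict[str, List[float]],
    uc: Dict[str, List[float]],
    num_samples: Sequence[int],
) -> None:
    """Truncate every non-"crossattn" entry of c and uc to the batch size, in place."""
    n = math.prod(num_samples)  # do_sample wraps num_samples in a list first
    for k in c:
        if k != "crossattn":
            # Lists stand in for tensors; do_sample also moves each slice to "cuda".
            c[k], uc[k] = c[k][:n], uc[k][:n]


def collect_model_inputs(batch: Dict[str, object], keys: Sequence[str]) -> Dict[str, object]:
    """Pick the batch entries named in keys, like the batch2model_input loop."""
    return {k: batch[k] for k in keys}


def latent_shape(num_samples: Sequence[int], C: int, H: int, W: int, F: int) -> Tuple[int, int, int, int]:
    """Latent batch shape: sample count, C channels, and H, W floor-divided by F."""
    return (math.prod(num_samples), C, H // F, W // F)
```

For example, `latent_shape([1], 4, 1024, 1024, 8)` returns `(1, 4, 128, 128)`. The values C=4 and F=8 are illustrative assumptions here, not taken from the excerpt. Because the division is floored, a height that is not a multiple of `F` loses its remainder: `H=1000, F=8` gives 125.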

Callers 1

run_txt2img · Function · 0.70

Calls 7

load_model · Function · 0.85
unload_model · Function · 0.85
decode_first_stage · Method · 0.80
get_batch · Function · 0.70
ema_scope · Method · 0.45
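In the listing, `load_model(model.conditioner)` and `unload_model(model.conditioner)` bracket the conditioning step as two plain calls. If an exception is raised between them, for example inside `get_batch`, the unload call is skipped. A context manager is one way to guarantee the pairing. The sketch below is a hypothetical alternative, not code from the repository. `loaded` and `demo` are invented names, and the load/unload callables are passed in because this page does not show what `load_model` and `unload_model` actually do.

```python
from contextlib import contextmanager
from typing import Callable, Iterator, List, TypeVar

T = TypeVar("T")


@contextmanager
def loaded(module: T, load: Callable[[T], None], unload: Callable[[T], None]) -> Iterator[T]:
    """Call load(module) on entry and unload(module) on exit, even if the body raises."""
    load(module)
    try:
        yield module
    finally:
        unload(module)


def demo() -> List[str]:
    # Hypothetical stand-ins that only record which calls were made.
    events: List[str] = []
    try:
        with loaded(
            "conditioner",
            lambda m: events.append(f"load {m}"),
            lambda m: events.append(f"unload {m}"),
        ):
            raise RuntimeError("simulated failure while conditioning")
    except RuntimeError:
        pass
    return events
```

`demo()` returns `["load conditioner", "unload conditioner"]`, so the unload still runs although the body raised. This is the same guarantee the surrounding `torch.no_grad()`, `autocast` and `ema_scope()` blocks already get from being context managers.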

Tested by

no test coverage detected