(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
)
| 510 | |
| 511 | |
| 512 | def do_sample( |
| 513 | model, |
| 514 | sampler, |
| 515 | value_dict, |
| 516 | num_samples, |
| 517 | H, |
| 518 | W, |
| 519 | C, |
| 520 | F, |
| 521 | force_uc_zero_embeddings: List = None, |
| 522 | batch2model_input: List = None, |
| 523 | return_latents=False, |
| 524 | filter=None, |
| 525 | ): |
| 526 | if force_uc_zero_embeddings is None: |
| 527 | force_uc_zero_embeddings = [] |
| 528 | if batch2model_input is None: |
| 529 | batch2model_input = [] |
| 530 | |
| 531 | st.text("Sampling") |
| 532 | |
| 533 | outputs = st.empty() |
| 534 | precision_scope = autocast |
| 535 | with torch.no_grad(): |
| 536 | with precision_scope("cuda"): |
| 537 | with model.ema_scope(): |
| 538 | num_samples = [num_samples] |
| 539 | load_model(model.conditioner) |
| 540 | batch, batch_uc = get_batch( |
| 541 | get_unique_embedder_keys_from_conditioner(model.conditioner), |
| 542 | value_dict, |
| 543 | num_samples, |
| 544 | ) |
| 545 | for key in batch: |
| 546 | if isinstance(batch[key], torch.Tensor): |
| 547 | print(key, batch[key].shape) |
| 548 | elif isinstance(batch[key], list): |
| 549 | print(key, [len(l) for l in batch[key]]) |
| 550 | else: |
| 551 | print(key, batch[key]) |
| 552 | c, uc = model.conditioner.get_unconditional_conditioning( |
| 553 | batch, |
| 554 | batch_uc=batch_uc, |
| 555 | force_uc_zero_embeddings=force_uc_zero_embeddings, |
| 556 | ) |
| 557 | unload_model(model.conditioner) |
| 558 | |
| 559 | for k in c: |
| 560 | if not k == "crossattn": |
| 561 | c[k], uc[k] = map( |
| 562 | lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc) |
| 563 | ) |
| 564 | |
| 565 | additional_model_inputs = {} |
| 566 | for k in batch2model_input: |
| 567 | additional_model_inputs[k] = batch[k] |
| 568 | |
| 569 | shape = (math.prod(num_samples), C, H // F, W // F) |
no test coverage detected