MCPcopy
hub / github.com/Stability-AI/generative-models / get_batch

Function get_batch

scripts/demo/streamlit_helpers.py:600–654  ·  view source on GitHub ↗
(keys, value_dict, N: Union[List, ListConfig], device="cuda")

Source from the content-addressed store, hash-verified

598
599
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Build the conditional and unconditional conditioning batches.

    Each entry of ``keys`` is looked up in ``value_dict`` and expanded to the
    leading batch dimensions given by ``N``.

    Args:
        keys: Iterable of conditioning key names to populate.
        value_dict: Mapping that supplies the raw values. The entries read
            depend on the key:

            * ``"txt"`` reads ``"prompt"`` and ``"negative_prompt"``.
            * ``"original_size_as_tuple"`` reads ``"orig_height"`` and
              ``"orig_width"``.
            * ``"crop_coords_top_left"`` reads ``"crop_coords_top"`` and
              ``"crop_coords_left"``.
            * ``"aesthetic_score"`` reads ``"aesthetic_score"`` and
              ``"negative_aesthetic_score"``.
            * ``"target_size_as_tuple"`` reads ``"target_height"`` and
              ``"target_width"``.
            * any other key reads ``value_dict[key]`` and passes it through
              unchanged.
        N: Sequence of leading batch dimensions, e.g. ``[num_samples]``.
        device: Device on which the tensors are created.

    Returns:
        A ``(batch, batch_uc)`` tuple of dicts. ``batch`` holds the
        conditional values. ``batch_uc`` holds the negative prompt and
        negative aesthetic score where those keys were requested, plus an
        independent clone of every other tensor in ``batch``. Non-tensor
        values other than ``"txt"`` are not copied into ``batch_uc``.

    Raises:
        KeyError: If ``value_dict`` lacks an entry required by one of ``keys``.
    """
    # Hardcoded demo setups; might undergo some changes in the future

    def _tiled(values):
        # Allocate directly on the target device, then tile a 1-D tensor of
        # length k to shape (*N, k).
        return torch.tensor(values, device=device).repeat(*N, 1)

    def _tiled_text(text):
        # Nested lists of shape N with `text` in every slot. numpy is handed
        # a plain tuple because N may be an omegaconf ListConfig.
        shape = tuple(N)
        return np.repeat([text], repeats=math.prod(shape)).reshape(shape).tolist()

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = _tiled_text(value_dict["prompt"])
            batch_uc["txt"] = _tiled_text(value_dict["negative_prompt"])
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = _tiled(
                [value_dict["orig_height"], value_dict["orig_width"]]
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = _tiled(
                [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = _tiled([value_dict["aesthetic_score"]])
            batch_uc["aesthetic_score"] = _tiled(
                [value_dict["negative_aesthetic_score"]]
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = _tiled(
                [value_dict["target_height"], value_dict["target_width"]]
            )
        else:
            # Unknown keys are forwarded untouched to the conditional batch.
            batch[key] = value_dict[key]

    # Tensors without a dedicated unconditional value are shared by value:
    # clone them so later in-place edits to one batch do not leak into the other.
    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = value.clone()
    return batch, batch_uc
655
656
657@torch.no_grad()

Callers 2

do_sampleFunction · 0.70
do_img2imgFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected