(keys, value_dict, N: Union[List, ListConfig], device="cuda")
| 598 | |
| 599 | |
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Assemble the conditional and unconditional conditioning batches.

    Hardcoded demo setups; might undergo some changes in the future.

    Args:
        keys: Conditioning keys to populate, processed in the given order.
        value_dict: Source values. Which entries are read depends on ``keys``:
            ``"txt"`` reads ``"prompt"`` and ``"negative_prompt"``;
            ``"original_size_as_tuple"`` reads ``"orig_height"``/``"orig_width"``;
            ``"crop_coords_top_left"`` reads ``"crop_coords_top"``/``"crop_coords_left"``;
            ``"target_size_as_tuple"`` reads ``"target_height"``/``"target_width"``;
            ``"aesthetic_score"`` reads ``"aesthetic_score"`` and
            ``"negative_aesthetic_score"``. Any other key is copied verbatim
            from ``value_dict[key]``.
        N: Batch shape. Text entries become nested lists shaped like ``N``;
            tensor entries are tiled to shape ``(*N, k)``, where ``k`` is the
            number of scalars for that key.
        device: Device the constructed tensors are moved to.

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` receives the negative prompt and
        negative aesthetic score where those keys are requested; every other
        ``torch.Tensor`` entry of ``batch`` is cloned into ``batch_uc``.
        Non-tensor entries copied through the fallback path are not mirrored.
    """
    batch = {}
    batch_uc = {}

    # Keys whose conditioning is a pair of scalars tiled across the batch,
    # mapped to the value_dict entries they are built from (in read order).
    scalar_pair_sources = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    def _tile(*scalars):
        # 1-D tensor of `scalars`, moved to `device`, tiled to (*N, len(scalars)).
        return torch.tensor(list(scalars)).to(device).repeat(*N, 1)

    def _broadcast_text(text):
        # One copy of `text` per batch element, arranged as a nested list of shape N.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    for key in keys:
        if key == "txt":
            batch["txt"] = _broadcast_text(value_dict["prompt"])
            batch_uc["txt"] = _broadcast_text(value_dict["negative_prompt"])
        elif key in scalar_pair_sources:
            first_name, second_name = scalar_pair_sources[key]
            batch[key] = _tile(value_dict[first_name], value_dict[second_name])
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = _tile(value_dict["aesthetic_score"])
            batch_uc["aesthetic_score"] = _tile(value_dict["negative_aesthetic_score"])
        else:
            batch[key] = value_dict[key]

    # Tensor conditioning with no dedicated negative value is shared with the
    # unconditional batch as an independent copy.
    for name, value in batch.items():
        if name not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[name] = torch.clone(value)

    return batch, batch_uc
| 655 | |
| 656 | |
| 657 | @torch.no_grad() |
no outgoing calls
no test coverage detected