(keys, value_dict, N: Union[List, ListConfig], device="cuda")
| 171 | |
| 172 | |
def get_batch(keys, value_dict, N: Union[int, List, ListConfig], device="cuda"):
    """Build the conditioning batch and its ``batch_uc`` counterpart for sampling.

    Hardcoded demo setups; might undergo some changes in the future.

    Args:
        keys: Names of the entries to populate. "txt", "original_size_as_tuple",
            "crop_coords_top_left", "aesthetic_score" and "target_size_as_tuple"
            are assembled from ``value_dict`` fields. Any other key is copied
            verbatim from ``value_dict[key]``.
        value_dict: Raw values. Depending on ``keys``, this function reads
            "prompt", "negative_prompt", "orig_height", "orig_width",
            "crop_coords_top", "crop_coords_left", "aesthetic_score",
            "negative_aesthetic_score", "target_height" and "target_width".
        N: Shape every value is tiled to. This is either a bare int or a
            sequence of ints (list / tuple / ListConfig).
        device: Device the tensor-valued entries are moved to.

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` (presumably the unconditional /
        negative-conditioning batch) receives the "negative_prompt" and
        "negative_aesthetic_score" values where those keys are requested.
        Every other tensor entry of ``batch`` is cloned into it unchanged.
        Non-tensor entries from the fallback branch appear only in ``batch``.
    """
    # Normalise N once: accept a bare int as well as any sequence of ints,
    # and compute the total sample count a single time.
    shape = (N,) if isinstance(N, int) else tuple(N)
    num_samples = math.prod(shape)

    def _tile_text(text):
        # Replicate one string into a nested Python list shaped like `shape`.
        return np.repeat([text], repeats=num_samples).reshape(shape).tolist()

    def _tile_tensor(values):
        # Move a 1-D vector to `device` and tile it to (*shape, len(values)).
        return torch.tensor(values).to(device).repeat(*shape, 1)

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = _tile_text(value_dict["prompt"])
            batch_uc["txt"] = _tile_text(value_dict["negative_prompt"])
        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = _tile_tensor(
                [value_dict["orig_height"], value_dict["orig_width"]]
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = _tile_tensor(
                [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = _tile_tensor([value_dict["aesthetic_score"]])
            batch_uc["aesthetic_score"] = _tile_tensor(
                [value_dict["negative_aesthetic_score"]]
            )
        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = _tile_tensor(
                [value_dict["target_height"], value_dict["target_width"]]
            )
        else:
            batch[key] = value_dict[key]

    # Tensor entries without an explicit counterpart are shared with batch_uc.
    # They are cloned so that later in-place edits to one dict cannot leak
    # into the other.
    for key, value in batch.items():
        if key not in batch_uc and isinstance(value, torch.Tensor):
            batch_uc[key] = torch.clone(value)
    return batch, batch_uc
| 228 | |
| 229 | |
| 230 | def get_input_image_tensor(image: Image.Image, device="cuda"): |
no outgoing calls
no test coverage detected