MCPcopy
hub / github.com/Stability-AI/generative-models / get_batch

Function get_batch

sgm/inference/helpers.py:173–227  ·  view source on GitHub ↗
(keys, value_dict, N: Union[List, ListConfig], device="cuda")

Source from the content-addressed store, hash-verified

171
172
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Build the conditional and unconditional conditioning batches.

    Hardcoded demo setups; might undergo some changes in the future.

    Args:
        keys: Conditioning keys to populate, handled in the given order.
        value_dict: Raw values. Which entries are read depends on ``keys``:
            ``"txt"`` reads ``prompt`` and ``negative_prompt``;
            ``"original_size_as_tuple"`` reads ``orig_height``/``orig_width``;
            ``"crop_coords_top_left"`` reads ``crop_coords_top``/``crop_coords_left``;
            ``"aesthetic_score"`` reads ``aesthetic_score`` and
            ``negative_aesthetic_score``;
            ``"target_size_as_tuple"`` reads ``target_height``/``target_width``.
            Any other key is copied verbatim from ``value_dict[key]``.
        N: Batch shape. Text entries are reshaped to ``N``; tensor entries
            are built from a 1-D tensor of k values and tiled to ``(*N, k)``.
        device: Device each tensor entry is moved to.

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` receives the negative prompt and
        negative aesthetic score where those keys are requested; every other
        tensor entry of ``batch`` is cloned into it. Non-tensor passthrough
        values appear only in ``batch``.
    """

    def _tile_values(*values):
        # 1-D tensor -> `device` -> tiled so the last dim holds the values.
        return torch.tensor(list(values)).to(device).repeat(*N, 1)

    def _tile_text(text):
        # One copy of `text` per batch element, laid out in shape N.
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    # Keys whose conditional value is a pair read from `value_dict`, with no
    # dedicated unconditional value (they get cloned at the end).
    pair_sources = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    batch, batch_uc = {}, {}
    for key in keys:
        if key == "txt":
            batch["txt"] = _tile_text(value_dict["prompt"])
            batch_uc["txt"] = _tile_text(value_dict["negative_prompt"])
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = _tile_values(value_dict["aesthetic_score"])
            batch_uc["aesthetic_score"] = _tile_values(
                value_dict["negative_aesthetic_score"]
            )
        elif key in pair_sources:
            first, second = pair_sources[key]
            batch[key] = _tile_values(value_dict[first], value_dict[second])
        else:
            batch[key] = value_dict[key]

    # Any tensor without an explicit unconditional counterpart is shared
    # (as an independent copy) between the two batches.
    batch_uc.update(
        {
            name: torch.clone(value)
            for name, value in batch.items()
            if name not in batch_uc and isinstance(value, torch.Tensor)
        }
    )
    return batch, batch_uc
228
229
230def get_input_image_tensor(image: Image.Image, device="cuda"):

Callers 2

do_sample — Function · 0.70
do_img2img — Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected