MCPcopy
hub / github.com/hpcaitech/Open-Sora / collect_references_batch

Function collect_references_batch

opensora/utils/inference_utils.py:171–196  ·  view source on GitHub ↗
(reference_paths, vae, image_size)

Source from the content-addressed store, hash-verified

169
170
def _vae_accepted_length(num_frames, vae):
    """Return the largest frame count <= ``num_frames`` that ``vae`` accepts.

    The accepted lengths depend on the VAE configuration:

    * ``vae.micro_frame_size is None``: lengths of the form ``4k + 1``.
    * ``vae.temporal_overlap`` falsy: multiples of ``vae.micro_frame_size``.
    * otherwise: lengths of the form ``k * (vae.micro_frame_size - 1) + 1``.

    The result is < 1 when ``num_frames`` is too short to keep any frame
    (e.g. fewer frames than ``micro_frame_size`` without temporal overlap,
    or ``num_frames == 0``); callers must check for that.
    """
    micro = vae.micro_frame_size
    if micro is None:
        return (num_frames - 1) // 4 * 4 + 1
    if not vae.temporal_overlap:
        return num_frames // micro * micro
    stride = micro - 1
    return (num_frames - 1) // stride * stride + 1


def collect_references_batch(reference_paths, vae, image_size):
    """Read and VAE-encode the reference clips for every sample in a batch.

    Args:
        reference_paths: one entry per sample. An empty entry (``""`` or
            ``None``) means the sample has no references; otherwise the entry
            is a ``;``-separated list of paths, each passed to
            ``read_from_path``. Empty segments (e.g. from a trailing ``;``)
            are ignored.
        vae: encoder exposing ``encode``, ``device``, ``dtype`` and
            ``micro_frame_size`` (and ``temporal_overlap`` when
            ``micro_frame_size`` is not None).
        image_size: target size forwarded to ``read_from_path`` (read with the
            ``"resize_crop"`` transform).

    Returns:
        A list with one entry per sample. Each entry is a (possibly empty) list
        holding one encoded reference per path, with the batch dimension that
        was added for ``vae.encode`` squeezed back out.

    Raises:
        ValueError: if a reference clip is too short to keep any frame after
            trimming it to a length the VAE accepts.
    """
    refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
    for reference_path in reference_paths:
        ref = []
        if reference_path:
            for r_path in reference_path.split(";"):
                if not r_path:
                    continue  # tolerate "a.mp4;" and "a.mp4;;b.mp4"
                r = read_from_path(r_path, image_size, transform_name="resize_crop")

                # r is presumably (C, T, H, W) per the layout comment above;
                # trim dim 1 so the clip length is one the VAE accepts.
                actual_t = r.size(1)
                target_t = _vae_accepted_length(actual_t, vae)
                if target_t < 1:
                    # Previously this silently produced an empty (or
                    # negatively-sliced) tensor that was handed to vae.encode.
                    raise ValueError(
                        f"Reference {r_path!r} has {actual_t} frame(s), too few to "
                        f"keep any after trimming to a VAE-accepted length "
                        f"(micro_frame_size={vae.micro_frame_size})."
                    )
                r = r[:, :target_t]

                # Add a batch dim for encode, then drop it again.
                r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
                ref.append(r_x.squeeze(0))
        refs_x.append(ref)
    return refs_x
197
198
199def extract_images_from_ref_paths(reference_paths, image_size):

Callers 3

main · Function · 0.90
main · Function · 0.90
run_inference · Function · 0.90

Calls 3

read_from_path · Function · 0.90
to · Method · 0.80
encode · Method · 0.45

Tested by

no test coverage detected