(reference_paths, vae, image_size)
| 169 | |
| 170 | |
def _vae_compatible_num_frames(actual_t, vae):
    """Return how many leading frames of a clip of length *actual_t* to keep.

    The rule depends on the VAE configuration (``vae.micro_frame_size`` and
    ``vae.temporal_overlap``). A non-empty clip is never cut down to zero
    frames.
    """
    if vae.micro_frame_size is None:
        # Keep a length of the form 4k + 1 (presumably the VAE's temporal
        # compression constraint -- confirm against the VAE implementation).
        target_t = (actual_t - 1) // 4 * 4 + 1
    elif not vae.temporal_overlap:
        # Keep a whole number of micro-frames.
        target_t = actual_t // vae.micro_frame_size * vae.micro_frame_size
    else:
        # Stride of micro_frame_size - 1 plus one frame: presumably consecutive
        # micro-frames overlap by one frame (confirm).
        stride = vae.micro_frame_size - 1
        target_t = (actual_t - 1) // stride * stride + 1
    if target_t < 1:
        # A clip shorter than one micro-frame (e.g. a single-image reference
        # without temporal overlap) would otherwise be truncated to zero
        # frames; keep it whole instead of encoding an empty clip.
        target_t = actual_t
    return target_t


def collect_references_batch(reference_paths, vae, image_size):
    """Read and VAE-encode the reference clips for every sample in a batch.

    Args:
        reference_paths: one entry per sample. Each entry is an empty string
            or ``None`` (no references for that sample), or a ``;``-separated
            list of paths given to ``read_from_path``. Segments are stripped
            of surrounding whitespace, and blank segments (e.g. from a
            trailing ``;``) are skipped.
        vae: encoder providing ``encode``, ``device``, ``dtype``,
            ``micro_frame_size`` and ``temporal_overlap``.
        image_size: forwarded unchanged to ``read_from_path``.

    Returns:
        A list with one entry per sample. Each entry is a (possibly empty)
        list holding one encoded latent per reference path, with the leading
        batch dimension added for ``vae.encode`` squeezed back out.
    """
    refs_x = []  # refs_x: [batch, ref_num, C, T, H, W]
    for reference_path in reference_paths:
        # Drop blank segments so inputs like "a.mp4;" or "a.mp4; b.mp4" never
        # hand an empty or space-padded path to read_from_path.
        r_paths = [p.strip() for p in (reference_path or "").split(";") if p.strip()]
        ref = []
        for r_path in r_paths:
            r = read_from_path(r_path, image_size, transform_name="resize_crop")
            # Trim to a frame count the VAE accepts. Assumes r is laid out as
            # (C, T, H, W), i.e. dim 1 is time -- TODO confirm.
            r = r[:, : _vae_compatible_num_frames(r.size(1), vae)]
            r_x = vae.encode(r.unsqueeze(0).to(vae.device, vae.dtype))
            ref.append(r_x.squeeze(0))
        refs_x.append(ref)
    return refs_x
| 197 | |
| 198 | |
| 199 | def extract_images_from_ref_paths(reference_paths, image_size): |
no test coverage detected