(self, batch, batch_idx)
| 138 | self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} |
| 139 | |
def test_step(self, batch, batch_idx):
    """Encode one sample with ``self.pipe`` and cache the result to disk.

    When the batch carries a video, this encodes the prompt text and the
    video latents. It also encodes an image embedding when ``"first_frame"``
    is present in the batch. All of these are written with ``torch.save`` to
    ``<path>.tensors.pth``. If that file already exists, the sample is
    skipped and a message is printed.

    Args:
        batch: Mapping with keys ``"text"``, ``"video"`` and ``"path"``, and
            optionally ``"first_frame"``. Only element 0 of ``"text"``,
            ``"path"`` and ``"first_frame"`` is read. Presumably the loader
            uses batch size 1 (TODO confirm).
        batch_idx: Unused here.

    Returns:
        None. The effect is the file written to disk.
    """
    text, video, path = batch["text"][0], batch["video"], batch["path"][0]

    # Point the pipeline at this module's current device before encoding.
    self.pipe.device = self.device
    if video is None:
        return

    pth_path = path + ".tensors.pth"
    if os.path.exists(pth_path):
        print(f"File {pth_path} already exists, skipping.")
        return

    # prompt
    prompt_emb = self.pipe.encode_prompt(text)
    # video
    video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
    latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
    # image (only when the batch supplies a first frame)
    if "first_frame" in batch:
        # NOTE(review): assumes first_frame[0] converts to an array PIL can
        # read (e.g. HxWxC uint8). TODO confirm against the dataset.
        first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
        # The code unpacks video.shape as 5 dims: (_, _, frames, H, W).
        _, _, num_frames, height, width = video.shape
        image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
    else:
        image_emb = {}

    data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
    # Write to a temporary name, then atomically rename into place. The
    # existence check above treats any file at pth_path as a finished cache
    # entry. A save interrupted mid-write must therefore never leave a
    # truncated file at that path; it would be skipped on every later run.
    tmp_path = pth_path + ".tmp"
    torch.save(data, tmp_path)
    os.replace(tmp_path, pth_path)
| 163 | |
| 164 | class Camera(object): |
| 165 | def __init__(self, c2w): |
nothing calls this directly
no test coverage detected