| 440 | |
| 441 | |
| 442 | class VideoDataset(Dataset): |
def __init__(
    self,
    instance_data_root: Optional[str] = None,
    dataset_name: Optional[str] = None,
    dataset_config_name: Optional[str] = None,
    caption_column: str = "text",
    video_column: str = "video",
    height: int = 480,
    width: int = 720,
    video_reshape_mode: str = "center",
    fps: int = 8,
    max_num_frames: int = 49,
    skip_frames_start: int = 0,
    skip_frames_end: int = 0,
    cache_dir: Optional[str] = None,
    id_token: Optional[str] = None,
) -> None:
    """Record the configuration, then load prompts and videos eagerly.

    Every argument is stored on the instance under its own name. The
    loading and preprocessing helpers invoked here are defined outside
    this block and presumably read those attributes -- confirm against
    their definitions.

    Args:
        instance_data_root: Optional local root; wrapped in ``Path`` when
            given, otherwise kept as ``None``.
        dataset_name: When not ``None``, ``_load_dataset_from_hub()`` is
            called; otherwise ``_load_dataset_from_local_path()`` is.
        dataset_config_name: Stored unchanged.
        caption_column: Stored unchanged.
        video_column: Stored unchanged.
        height: Stored unchanged.
        width: Stored unchanged.
        video_reshape_mode: Stored unchanged.
        fps: Stored unchanged.
        max_num_frames: Stored unchanged.
        skip_frames_start: Stored unchanged.
        skip_frames_end: Stored unchanged.
        cache_dir: Stored unchanged.
        id_token: String prepended verbatim (no separator is inserted)
            to every prompt; a falsy value becomes ``""``.

    Raises:
        ValueError: If the loader returns a different number of prompts
            and video paths.
    """
    super().__init__()

    # Only build a Path when a root was actually supplied.
    if instance_data_root is not None:
        self.instance_data_root = Path(instance_data_root)
    else:
        self.instance_data_root = None

    self.dataset_name = dataset_name
    self.dataset_config_name = dataset_config_name
    self.cache_dir = cache_dir

    self.caption_column = caption_column
    self.video_column = video_column

    self.height = height
    self.width = width
    self.video_reshape_mode = video_reshape_mode

    self.fps = fps
    self.max_num_frames = max_num_frames
    self.skip_frames_start = skip_frames_start
    self.skip_frames_end = skip_frames_end

    # Normalise a missing/empty token to "" so it can be concatenated below.
    self.id_token = id_token if id_token else ""

    # A dataset name takes precedence over the local root.
    load = (
        self._load_dataset_from_hub
        if dataset_name is not None
        else self._load_dataset_from_local_path
    )
    raw_prompts, video_paths = load()
    self.instance_video_paths = video_paths
    self.instance_prompts = [self.id_token + caption for caption in raw_prompts]

    self.num_instance_videos = len(self.instance_video_paths)
    if len(self.instance_prompts) != self.num_instance_videos:
        raise ValueError(
            f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
        )

    # Preprocessing runs once, up front; __getitem__ indexes into its result.
    self.instance_videos = self._preprocess_data()
| 491 | |
def __len__(self) -> int:
    """Return the stored ``num_instance_videos`` attribute."""
    video_count = self.num_instance_videos
    return video_count
| 494 | |
def __getitem__(self, index):
    """Return the prompt/video pair stored at ``index``.

    The result is a dict with the keys ``"instance_prompt"`` and
    ``"instance_video"``, taken from ``self.instance_prompts`` and
    ``self.instance_videos`` respectively.
    """
    prompt = self.instance_prompts[index]
    video = self.instance_videos[index]
    sample = {"instance_prompt": prompt}
    sample["instance_video"] = video
    return sample