MCPcopy
hub / github.com/zai-org/CogVideo / VideoDataset

Class VideoDataset

finetune/train_cogvideox_image_to_video_lora.py:442–662 (only lines 442–499 are reproduced below)  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

440
441
class VideoDataset(Dataset):
    """Map-style dataset of (prompt, video) pairs.

    Prompt and video-path lists come from a Hugging Face Hub dataset when
    ``dataset_name`` is set, otherwise from the local ``instance_data_root``.
    Every prompt is prefixed with ``id_token``, and all videos are
    preprocessed eagerly in ``__init__`` (via ``_preprocess_data``), so
    ``__getitem__`` only indexes into the cached results.

    NOTE(review): ``_load_dataset_from_hub``, ``_load_dataset_from_local_path``
    and ``_preprocess_data`` are not shown here. ``__init__`` unpacks each
    loader's result as ``(prompts, video_paths)`` and ``__getitem__`` indexes
    the result of ``_preprocess_data`` by position — confirm those contracts.
    """

    def __init__(
        self,
        instance_data_root: Optional[str] = None,
        dataset_name: Optional[str] = None,
        dataset_config_name: Optional[str] = None,
        caption_column: str = "text",
        video_column: str = "video",
        height: int = 480,
        width: int = 720,
        video_reshape_mode: str = "center",
        fps: int = 8,
        max_num_frames: int = 49,
        skip_frames_start: int = 0,
        skip_frames_end: int = 0,
        cache_dir: Optional[str] = None,
        id_token: Optional[str] = None,
    ) -> None:
        """Load prompts and video paths, then preprocess every video.

        Args:
            instance_data_root: Local dataset directory, stored as a ``Path``.
                Only used when ``dataset_name`` is ``None``.
            dataset_name: Hub dataset name. When set, it takes precedence over
                ``instance_data_root``.
            dataset_config_name, caption_column, video_column, height, width,
            video_reshape_mode, fps, max_num_frames, skip_frames_start,
            skip_frames_end, cache_dir: Stored unchanged as attributes for the
                loading/preprocessing helpers; not read directly here.
            id_token: String prepended verbatim to every prompt (no separator
                is inserted). ``None`` means no prefix.

        Raises:
            ValueError: If neither ``dataset_name`` nor ``instance_data_root``
                is given, or if the number of prompts and videos differ.
        """
        super().__init__()

        # Fail fast: without a Hub dataset the local loader is the only data
        # source, and it would otherwise be handed a ``None`` root directory.
        if dataset_name is None and instance_data_root is None:
            raise ValueError(
                "Either `dataset_name` or `instance_data_root` must be provided to load the video dataset."
            )

        self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
        self.dataset_name = dataset_name
        self.dataset_config_name = dataset_config_name
        self.caption_column = caption_column
        self.video_column = video_column
        self.height = height
        self.width = width
        self.video_reshape_mode = video_reshape_mode
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frames_start = skip_frames_start
        self.skip_frames_end = skip_frames_end
        self.cache_dir = cache_dir
        self.id_token = id_token or ""

        # A Hub dataset name wins over a local directory when both are given.
        if dataset_name is not None:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
        else:
            self.instance_prompts, self.instance_video_paths = self._load_dataset_from_local_path()

        # Plain concatenation: any separator (e.g. a trailing space) must be
        # part of ``id_token`` itself.
        self.instance_prompts = [self.id_token + prompt for prompt in self.instance_prompts]

        self.num_instance_videos = len(self.instance_video_paths)
        if self.num_instance_videos != len(self.instance_prompts):
            raise ValueError(
                f"Expected length of instance prompts and videos to be the same but found {len(self.instance_prompts)=} and {len(self.instance_video_paths)=}. Please ensure that the number of caption prompts and videos match in your dataset."
            )

        # All videos are preprocessed up front; __getitem__ just indexes this.
        self.instance_videos = self._preprocess_data()

    def __len__(self) -> int:
        """Return the number of (prompt, video) pairs."""
        return self.num_instance_videos

    def __getitem__(self, index):
        """Return the prompt and preprocessed video at ``index``.

        The result is a dict with keys ``"instance_prompt"`` (the
        id-token-prefixed caption) and ``"instance_video"`` (the element of
        the ``_preprocess_data`` output at that position).
        """
        return {
            "instance_prompt": self.instance_prompts[index],
            "instance_video": self.instance_videos[index],
        }

Callers (1)

main (function) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected