Split the dataset into `num_splits` sub-datasets. Args: num_splits: the number of splits to create. Returns: List[RLHFDataset]: the list of RLHFDataset splits. Raises: ValueError: if num_splits is not a positive integer.
(self, num_splits: int)
| 477 | return cls._process_multi_modal_info(messages, image_patch_size=image_patch_size, config=config) |
| 478 | |
def split(self, num_splits: int):
    """Split the dataset into ``num_splits`` equally sized sub-datasets.

    Samples are assigned to splits as contiguous index ranges in the current
    dataframe order. When the sample count is not divisible by ``num_splits``,
    the trailing remainder is dropped (with a warning) so that every split has
    the same number of samples.

    Args:
        num_splits: Number of splits to produce. Must be a positive ``int``
            (``bool`` is rejected) and must not exceed the number of samples.

    Returns:
        List[RLHFDataset]: ``num_splits`` datasets, each holding
        ``len(self.dataframe) // num_splits`` samples.

    Raises:
        ValueError: If ``num_splits`` is not a positive integer, if the
            dataframe is ``None`` or empty, or if ``num_splits`` is larger
            than the number of samples (every split would otherwise be empty).
        AttributeError: If the dataset has no ``dataframe`` attribute.
    """
    # bool is a subclass of int; reject it so split(True) is not silently treated as 1 split.
    if isinstance(num_splits, bool) or not isinstance(num_splits, int) or num_splits <= 0:
        raise ValueError(f"num_splits must be a positive integer, got {num_splits}")

    if not hasattr(self, "dataframe"):
        raise AttributeError(
            "dataframe not found in RLHFDataset\n"
            "reason: _read_files_and_tokenize() not called or Parquet file loading failed"
        )
    if self.dataframe is None:
        # Message reads "RLHFDataset dataframe is None!".
        raise ValueError("RLHFDataset dataframe 为 None!")

    total_samples = len(self.dataframe)
    logging.info("total_samples: %d", total_samples)
    if total_samples == 0:
        raise ValueError("Cannot split an empty dataset")
    if num_splits > total_samples:
        # Without this check split_size would be 0 and every returned split would be empty.
        raise ValueError(f"num_splits ({num_splits}) exceeds the number of samples ({total_samples})")

    split_size, remainder = divmod(total_samples, num_splits)
    if remainder:
        # Drop the trailing remainder so all splits have identical size.
        effective_samples = total_samples - remainder
        logging.warning("Dropping %d samples, effective samples: %d", remainder, effective_samples)

    splits = []
    for i in range(num_splits):
        start_idx = i * split_size
        split_dataframe = self.dataframe.select(range(start_idx, start_idx + split_size))

        # NOTE(review): building a fresh RLHFDataset from data_files presumably re-reads and
        # tokenizes the full dataset once per split, only for the dataframe to be replaced
        # below; consider a cheaper copy if __init__ is expensive — verify against __init__.
        split_dataset = RLHFDataset(
            data_files=self.data_files,
            tokenizer=self.tokenizer,
            config=self.config,
            processor=self.processor,
            max_samples=self.max_samples,
        )
        split_dataset.dataframe = split_dataframe
        split_dataset.serialize_dataset = self.serialize_dataset
        split_dataset.original_data_files = self.original_data_files
        splits.append(split_dataset)

    return splits
| 533 | |
| 534 | |
| 535 | def get_dataset_class(data_config: DictConfig): |