| 66 | |
| 67 | |
def create_dataset_and_dataloader(args, epoch: int):
    """Build the dataset and its dataloader for one epoch.

    The data file is ``{args.data_prefix}.{epoch}.parquet``; a progress
    message naming that file is printed before loading.

    Args:
        args: Object exposing ``data_prefix`` and ``batch_max_len``.
        epoch: Epoch index used to select the data file.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    filename = f"{args.data_prefix}.{epoch}.parquet"
    print(f"Loading epoch {epoch} data from {filename}...")

    # Rank and replica count come from `dist`, so each process hands its own
    # position to the dataset.
    # NOTE(review): presumably used for sharding inside OpenchatDataset -- confirm.
    dataset_options = {
        "dataset_filename": filename,
        "batch_max_length": args.batch_max_len,
        "rank": dist.get_rank(),
        "num_replicas": dist.get_world_size(),
    }
    epoch_dataset = OpenchatDataset(**dataset_options)

    # batch_size=None turns off the DataLoader's automatic batching.
    # NOTE(review): presumably the dataset already yields full batches -- confirm.
    loader_options = {
        "batch_size": None,
        "num_workers": 1,
        "prefetch_factor": 8,
        "pin_memory": True,
    }
    loader = DataLoader(epoch_dataset, **loader_options)

    return epoch_dataset, loader
| 92 | |
| 93 | |
| 94 | def create_model(args): |