| 302 | |
| 303 | |
class WebDatasetReader(IterableDataReader):
    """Stream samples from a list of WebDataset tar shards.

    Each manifest entry is a ``(tar_path, label_jsonl_path, num_items,
    num_seconds)`` tuple describing one shard. Every decoded sample is passed
    through a ``SampleDecoder`` built from the tar-path -> label-jsonl mapping.

    Outside evaluation mode, ``set_epoch`` reshuffles the shard order and the
    epoch number seeds the sample-level shuffle buffer. Until ``set_epoch`` is
    called, shards are read in manifest order.
    """

    def __init__(
        self,
        manifests: List[Tuple[str, str, int, float]],
        evaluation: bool = False,
        shuffle_buffer_size: int = 20000,
        sample_rate: int = 24000,
    ):
        """
        Args:
            manifests: one ``(tar_path, label_jsonl_path, num_items,
                num_seconds)`` entry per shard.
            evaluation: when True, shard order and sample order are never
                shuffled.
            shuffle_buffer_size: buffer size for the sample-level shuffle.
            sample_rate: forwarded to ``SampleDecoder`` (presumably the target
                audio sample rate in Hz; confirm against SampleDecoder).
        """
        self.shuffle_buffer_size = shuffle_buffer_size
        self.evaluation = evaluation
        self.epoch = 0

        shard_paths = []
        labels_by_shard = {}
        total_items = 0
        total_seconds = 0.0
        for tar_path, label_path, count, seconds in manifests:
            shard_paths.append(tar_path)
            # A repeated tar path keeps the label file of its last occurrence.
            labels_by_shard[tar_path] = label_path
            total_items += count
            total_seconds += seconds

        # ``orig_urls`` keeps manifest order; ``urls`` is the working copy
        # that ``set_epoch`` may reorder.
        self.orig_urls = shard_paths
        self.tar_to_label = labels_by_shard
        self.num_items = total_items
        self.num_seconds = total_seconds
        self.urls = list(shard_paths)

        self.sample_decoder = SampleDecoder(
            tar_to_label=self.tar_to_label,
            sample_rate=sample_rate,
        )
        self.sample_rate = sample_rate

    def set_epoch(self, epoch: int):
        """
        Record ``epoch`` and rebuild the shard order for it.

        The shard list is reset to manifest order. Outside evaluation mode it
        is then shuffled with ``random.Random(epoch)``, so each epoch gets a
        deterministic order. The epoch also seeds the sample shuffle buffer
        used by ``__iter__``.
        """
        self.epoch = epoch
        self.urls = shard_order = list(self.orig_urls)
        if self.evaluation:
            return
        random.Random(epoch).shuffle(shard_order)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        """Build a fresh webdataset pipeline over ``self.urls`` and iterate it.

        Shards are handed to ``wds.split_by_node`` / ``wds.split_by_worker``
        with webdataset's own shard shuffling disabled; shard order comes from
        ``self.urls`` as set by ``set_epoch``. Samples are decoded and mapped
        through ``self.sample_decoder``. Outside evaluation mode they are also
        shuffled through a ``shuffle_buffer_size`` buffer seeded with
        ``self.epoch``.
        """
        source = wds.WebDataset(
            self.urls,
            shardshuffle=False,
            workersplitter=wds.split_by_worker,
            nodesplitter=wds.split_by_node,
        )
        decoded = source.decode()
        samples = decoded.map(self.sample_decoder)
        if self.evaluation:
            return iter(samples)
        shuffled = samples.shuffle(self.shuffle_buffer_size, seed=self.epoch)
        return iter(shuffled)

    def __len__(self) -> int:
        """Return ``num_items`` summed over every manifest entry (not divided per node/worker)."""
        return self.num_items
| 357 | |
| 358 | |
| 359 | class JsonlDatasetReader(IterableDataReader): |
no outgoing calls
no test coverage detected