(self)
| 465 | assert all(call["num_shards"] is None for call in forwarded_calls) |
| 466 | |
def test_sort(self):
    """Exercise ``DatasetDict.sort`` on the "filename" column in both directions.

    First the dummy dataset dict is sorted ascending. The result is then
    sorted again with ``reverse=True`` and explicit per-split
    ``indices_cache_file_names`` pointing into a temporary directory.

    After each sort, the test asserts two things:
    * the split keys (and their order) match those of the unsorted dict;
    * the "train" filenames, reduced to the text after their last "_",
      come out in ascending or descending order of the zero-padded strings
      "000".."029".
    """

    def train_suffixes(dataset_dict):
        # Text after the last underscore of every "train" filename, in row order.
        return [name.split("_")[-1] for name in dataset_dict["train"]["filename"]]

    ascending = sorted(f"{n:03d}" for n in range(30))
    descending = sorted(ascending, reverse=True)

    with tempfile.TemporaryDirectory() as tmp_dir:
        original = self._create_dummy_dataset_dict()

        # Plain ascending sort.
        by_name: DatasetDict = original.sort("filename")
        self.assertListEqual(list(original.keys()), list(by_name.keys()))
        self.assertListEqual(train_suffixes(by_name), ascending)

        # Descending sort of the already-sorted dict, with an explicit
        # indices cache file per split inside the temporary directory.
        cache_paths = {
            split: os.path.join(tmp_dir, f"{split}.arrow") for split in ("train", "test")
        }
        by_name_desc: DatasetDict = by_name.sort(
            "filename", indices_cache_file_names=cache_paths, reverse=True
        )
        self.assertListEqual(list(original.keys()), list(by_name_desc.keys()))
        self.assertListEqual(train_suffixes(by_name_desc), descending)

        # Drop the references while tmp_dir still exists, before the
        # context manager removes the directory.
        del original, by_name, by_name_desc
| 491 | |
| 492 | def test_shuffle(self): |
| 493 | with tempfile.TemporaryDirectory() as tmp_dir: |
nothing calls this directly
no test coverage detected