(self, in_memory)
| 2396 | self.assertNotEqual(d3._fingerprint, d2._fingerprint) |
| 2397 | |
| 2398 | def test_sort(self, in_memory): |
| 2399 | with tempfile.TemporaryDirectory() as tmp_dir: |
| 2400 | # Sort on a single key |
| 2401 | with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir) as dset: |
| 2402 | # Keep only 10 examples |
| 2403 | tmp_file = os.path.join(tmp_dir, "test.arrow") |
| 2404 | with dset.select(range(10), indices_cache_file_name=tmp_file) as dset: |
| 2405 | tmp_file = os.path.join(tmp_dir, "test_2.arrow") |
| 2406 | with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset: |
| 2407 | self.assertEqual(len(dset), 10) |
| 2408 | self.assertEqual(dset[0]["filename"], "my_name-train_8") |
| 2409 | self.assertEqual(dset[1]["filename"], "my_name-train_9") |
| 2410 | # Sort |
| 2411 | tmp_file = os.path.join(tmp_dir, "test_3.arrow") |
| 2412 | fingerprint = dset._fingerprint |
| 2413 | with dset.sort("filename", indices_cache_file_name=tmp_file) as dset_sorted: |
| 2414 | for i, row in enumerate(dset_sorted): |
| 2415 | self.assertEqual(int(row["filename"][-1]), i) |
| 2416 | self.assertDictEqual(dset.features, Features({"filename": Value("string")})) |
| 2417 | self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")})) |
| 2418 | self.assertNotEqual(dset_sorted._fingerprint, fingerprint) |
| 2419 | # Sort reversed |
| 2420 | tmp_file = os.path.join(tmp_dir, "test_4.arrow") |
| 2421 | fingerprint = dset._fingerprint |
| 2422 | with dset.sort("filename", indices_cache_file_name=tmp_file, reverse=True) as dset_sorted: |
| 2423 | for i, row in enumerate(dset_sorted): |
| 2424 | self.assertEqual(int(row["filename"][-1]), len(dset_sorted) - 1 - i) |
| 2425 | self.assertDictEqual(dset.features, Features({"filename": Value("string")})) |
| 2426 | self.assertDictEqual(dset_sorted.features, Features({"filename": Value("string")})) |
| 2427 | self.assertNotEqual(dset_sorted._fingerprint, fingerprint) |
| 2428 | # formatted |
| 2429 | dset.set_format("numpy") |
| 2430 | with dset.sort("filename") as dset_sorted_formatted: |
| 2431 | self.assertEqual(dset_sorted_formatted.format["type"], "numpy") |
| 2432 | # Sort on multiple keys |
| 2433 | with self._create_dummy_dataset(in_memory=in_memory, tmp_dir=tmp_dir, multiple_columns=True) as dset: |
| 2434 | tmp_file = os.path.join(tmp_dir, "test_5.arrow") |
| 2435 | fingerprint = dset._fingerprint |
| 2436 | # Throw error when reverse is a list of bools that does not match the length of column_names |
| 2437 | with pytest.raises(ValueError): |
| 2438 | dset.sort(["col_1", "col_2", "col_3"], reverse=[False]) |
| 2439 | with dset.shuffle(seed=1234, indices_cache_file_name=tmp_file) as dset: |
| 2440 | # Sort |
| 2441 | with dset.sort(["col_1", "col_2", "col_3"], reverse=[False, True, False]) as dset_sorted: |
| 2442 | for i, row in enumerate(dset_sorted): |
| 2443 | self.assertEqual(row["col_1"], i) |
| 2444 | self.assertDictEqual( |
| 2445 | dset.features, |
| 2446 | Features( |
| 2447 | { |
| 2448 | "col_1": Value("int64"), |
| 2449 | "col_2": Value("string"), |
| 2450 | "col_3": Value("bool"), |
| 2451 | } |
| 2452 | ), |
| 2453 | ) |
| 2454 | self.assertDictEqual( |
| 2455 | dset_sorted.features, |
nothing calls this directly
no test coverage detected