| 62 | return mock_info |
| 63 | |
def load(self, config):
    """Create the mock data for ``config`` and load the dataset from it.

    The mock data function writes its files into a ``__mock__`` sub-folder of
    the dataset root. ``OnlineResource.download`` is patched for the duration
    of the load so that each requested file is moved from that folder into
    the download root only at the moment it is requested.

    Args:
        config: Dataset configuration. It is passed to the mock data function
            and expanded as keyword arguments into ``datasets.load``.

    Returns:
        A tuple ``(dataset, mock_info)`` with the loaded dataset and the
        parsed return value of the mock data function.

    Raises:
        pytest.UsageError: If a resource requests a file that the mock data
            function did not create, or if files created by the mock data
            function are still left in the ``__mock__`` folder after loading.
    """
    # `datasets.home()` is patched to a temporary directory through the autouse fixture `test_home` in
    # test/test_prototype_builtin_datasets.py
    root = pathlib.Path(datasets.home()) / self.name
    # We cannot place the mock data upfront in `root`. Loading a dataset calls `OnlineResource.load`. In turn,
    # this will only download **and** preprocess if the file is not present. In other words, if we already place
    # the file in `root` before the resource is loaded, we are effectively skipping the preprocessing.
    # To avoid that we first place the mock data in a temporary directory and patch the download logic to move it
    # to `root` only when it is requested.
    tmp_mock_data_folder = root / "__mock__"
    tmp_mock_data_folder.mkdir(parents=True)

    mock_info = self._parse_mock_info(self.mock_data_fn(tmp_mock_data_folder, config))

    def patched_download(resource, root, **kwargs):
        # NOTE: the `root` parameter intentionally shadows the outer `root`; it is whatever directory the
        # resource is asked to download into.
        src = tmp_mock_data_folder / resource.file_name
        if not src.exists():
            # The trailing space after `{config}` keeps the two message parts from running together.
            raise pytest.UsageError(
                f"Dataset '{self.name}' requires the file {resource.file_name} for {config} "
                "but it was not created by the mock data function."
            )

        dst = root / resource.file_name
        # Moving into the directory `root` places the file at `root / src.name`.
        shutil.move(str(src), str(root))

        return dst

    with unittest.mock.patch(
        "torchvision.prototype.datasets.utils._resource.OnlineResource.download", new=patched_download
    ):
        dataset = datasets.load(self.name, **config)

    # Anything still left in the temporary folder was created by the mock data function but never requested.
    extra_files = list(tmp_mock_data_folder.glob("**/*"))
    if extra_files:
        raise pytest.UsageError(
            (
                f"Dataset '{self.name}' created the following files for {config} in the mock data function, "
                f"but they were not loaded:\n\n"
            )
            + "\n".join(str(file.relative_to(tmp_mock_data_folder)) for file in extra_files)
        )

    # `rmdir` only succeeds on an empty directory, which the check above guarantees at this point.
    tmp_mock_data_folder.rmdir()

    return dataset, mock_info
| 109 | |
| 110 | |
| 111 | def config_id(name, config): |