Loads the named dataset from a dataset collection by calling `tfds.load`. Args: dataset: `str`, the dataset name to load. split: which split(s) of the dataset to load. If `None`, will return all splits available for the dataset. loader_kwargs: `dict` (optional), keywor
(
self,
dataset: str,
split: Optional[Tree[splits_lib.SplitArg]] = None,
loader_kwargs: dict[str, Any] | None = None,
)
| 300 | self.loader_kwargs = loader_kwargs |
| 301 | |
| 302 | def load_dataset( |
| 303 | self, |
| 304 | dataset: str, |
| 305 | split: Optional[Tree[splits_lib.SplitArg]] = None, |
| 306 | loader_kwargs: dict[str, Any] | None = None, |
| 307 | ) -> Mapping[str, tf.data.Dataset]: |
| 308 | """Loads the named dataset from a dataset collection by calling `tfds.load`. |
| 309 | |
| 310 | Args: |
| 311 | dataset: `str`, the dataset name to load. |
| 312 | split: which split(s) of the dataset to load. If `None`, will return all |
| 313 | splits available for the dataset. |
| 314 | loader_kwargs: `dict` (optional), keyword arguments to be passed to the |
| 315 | `tfds.load` function. Refer to `tfds.load` documentation for a |
| 316 | comprehensive overview of the different loading options. |
| 317 | |
| 318 | Returns: |
| 319 | A `dict` of {`str`: tf.data.Dataset} for the desired dataset. |
| 320 | |
| 321 | Raises: |
| 322 | KeyError: if trying to load a dataset not included in the collection. |
| 323 | RuntimeError: if `load` return type is not a `dict` or a `list`. |
| 324 | """ |
| 325 | if not dataset: |
| 326 | raise TypeError('You must specify a non-empty dataset to load.') |
| 327 | |
| 328 | loader_kwargs = loader_kwargs or self.loader_kwargs or {} |
| 329 | |
| 330 | # with_info must be False (or it will change the return type of `tfds.load`) |
| 331 | if 'with_info' in loader_kwargs and loader_kwargs['with_info']: |
| 332 | logging.warning('`with_info` cannot be True, setting it to False') |
| 333 | loader_kwargs['with_info'] = False |
| 334 | |
| 335 | try: |
| 336 | dataset_reference = self.datasets[dataset] |
| 337 | except KeyError as e: |
| 338 | raise KeyError( |
| 339 | f'Dataset {dataset} is not included in this collection. ' |
| 340 | f'{self.collection.list_datasets(version=self.requested_version)}' |
| 341 | ) from e |
| 342 | |
| 343 | # If `split` is defined both as argument and in `loader_kwargs`, always keep |
| 344 | # the one defined as argument. |
| 345 | if split: |
| 346 | loader_kwargs['split'] = dataset_reference.get_split(split) |
| 347 | # Make sure we always return a dict of dicts. |
| 348 | if 'split' in loader_kwargs and isinstance(loader_kwargs['split'], str): |
| 349 | loader_kwargs['split'] = [loader_kwargs['split']] |
| 350 | |
| 351 | # Add the data dir from the reference to loader_kwargs if it is defined and |
| 352 | # not overridden in loader_kwargs. |
| 353 | if ( |
| 354 | dataset_reference.data_dir is not None |
| 355 | and 'data_dir' not in loader_kwargs |
| 356 | ): |
| 357 | loader_kwargs['data_dir'] = dataset_reference.data_dir |
| 358 | |
| 359 | load_output = load(dataset_reference.tfds_name(), **loader_kwargs) |