Reads the data source (files or TFDS) into a dataset.
(
self,
matched_files: Union[Dict[str, List[str]], List[str]],
dataset_fn,
input_context: Optional[tf.distribute.InputContext] = None,
)
| 370 | return matched_files |
| 371 | |
| 372 | def _read_data_source( |
| 373 | self, |
| 374 | matched_files: Union[Dict[str, List[str]], List[str]], |
| 375 | dataset_fn, |
| 376 | input_context: Optional[tf.distribute.InputContext] = None, |
| 377 | ): |
| 378 | """Reads the data source (files/tfds) to a dataset.""" |
| 379 | |
| 380 | def _files_to_dataset(files: List[str]) -> tf.data.Dataset: |
| 381 | if len(files) > 1: |
| 382 | if input_context and (len(files) < input_context.num_input_pipelines): |
| 383 | logging.warn( |
| 384 | ( |
| 385 | 'The number of files %d is less than the number of input ' |
| 386 | 'pipelines %d. We will send all input files to every worker. ' |
| 387 | 'Please consider sharding your data into more files.' |
| 388 | ), |
| 389 | len(files), |
| 390 | input_context.num_input_pipelines, |
| 391 | ) |
| 392 | return _read_files_then_shard( |
| 393 | files, |
| 394 | dataset_fn, |
| 395 | input_context, |
| 396 | sharding=self._sharding, |
| 397 | repeat=self._is_training and not self._cache) |
| 398 | else: |
| 399 | return _shard_files_then_read( |
| 400 | files, |
| 401 | dataset_fn, |
| 402 | input_context, |
| 403 | seed=self._seed, |
| 404 | is_training=self._is_training, |
| 405 | sharding=self._sharding, |
| 406 | cache=self._cache, |
| 407 | cycle_length=self._cycle_length, |
| 408 | block_length=self._block_length, |
| 409 | deterministic=self._deterministic) |
| 410 | elif len(files) == 1: |
| 411 | return _read_files_then_shard( |
| 412 | files, |
| 413 | dataset_fn, |
| 414 | input_context, |
| 415 | sharding=self._sharding, |
| 416 | repeat=self._is_training and not self._cache) |
| 417 | else: |
| 418 | raise ValueError('It is unexpected that `tfds_builder` is None and ' |
| 419 | 'there is also no `files`.') |
| 420 | |
| 421 | if self._tfds_name: |
| 422 | if isinstance(self._tfds_name, cfg.base_config.Config): |
| 423 | dataset = {} |
| 424 | for k, tfds_name in self._tfds_name.as_dict().items(): |
| 425 | dataset[k] = _read_tfds( |
| 426 | tfds_name=tfds_name, |
| 427 | tfds_data_dir=self._tfds_data_dir, |
| 428 | tfds_split=self._tfds_split, |
| 429 | tfds_skip_decoding_feature=self._tfds_skip_decoding_feature, |
No test coverage detected.