Apply a set of default rules to make a fast :class:`InputSource`. Args: input_source_or_dataflow (InputSource | DataFlow): trainer (Trainer): Returns: InputSource
(input_source_or_dataflow, trainer)
| 13 | |
| 14 | |
def apply_default_prefetch(input_source_or_dataflow, trainer):
    """
    Apply a set of default rules to make a fast :class:`InputSource`.

    Rules, applied in order:

    1. If ``input_source_or_dataflow`` is not already an :class:`InputSource`
       (e.g. it is a DataFlow), wrap it:

       - in :class:`FeedInput` when ``trainer`` is *exactly* a
         :class:`SimpleTrainer` (subclasses do not count);
       - in :class:`QueueInput` otherwise.

       An existing :class:`InputSource` is used unchanged.
    2. If ``trainer`` has a ``devices`` attribute with more than one entry,
       a :class:`QueueInput` from step 1 is further wrapped in
       :class:`StagingInput`.

    Args:
        input_source_or_dataflow (InputSource | DataFlow): the input to use or
            wrap.
        trainer (Trainer): the trainer that will consume the input. Its exact
            type and its optional ``devices`` attribute drive the rules above.

    Returns:
        InputSource: the input source, possibly wrapped by the rules above.

    Raises:
        AssertionError: if ``trainer`` is a :class:`SimpleTrainer` (or
            subclass) exposing more than one device.
    """
    # Named ``input_source`` rather than ``input`` to avoid shadowing the builtin.
    if isinstance(input_source_or_dataflow, InputSource):
        input_source = input_source_or_dataflow
    elif type(trainer) is SimpleTrainer:
        # Exact type check on purpose, to mimic the same behavior as the old
        # trainer interface: subclasses of SimpleTrainer fall through to QueueInput.
        input_source = FeedInput(input_source_or_dataflow)
    else:
        logger.info("Automatically applying QueueInput on the DataFlow.")
        input_source = QueueInput(input_source_or_dataflow)

    if hasattr(trainer, 'devices'):
        towers = trainer.devices
        if len(towers) > 1:  # staging seems to only help on >1 GPUs
            assert not isinstance(trainer, SimpleTrainer), \
                "apply_default_prefetch: got a SimpleTrainer with more than one device."

            if isinstance(input_source, QueueInput):
                logger.info("Automatically applying StagingInput on the DataFlow.")
                input_source = StagingInput(input_source)
    return input_source
| 44 | |
| 45 | |
| 46 | def launch_train_with_config(config, trainer): |
no test coverage detected