def launch_train_with_config(config, trainer):
    """
    Train with a :class:`TrainConfig` and a :class:`Trainer`.
    This provides the older, simpler training interface. It does the following
    three things (you can easily do them yourself if you need more control):

    1. Set up the input with automatic prefetching heuristics,
       from `config.data` or `config.dataflow`.
    2. Call `trainer.setup_graph` with the input as well as `config.model`.
    3. Call `trainer.train_with_defaults` with the rest of the attributes of config.

    See the `related tutorial
    <https://tensorpack.readthedocs.io/tutorial/training-interface.html#with-modeldesc-and-trainconfig>`_
    to learn more.

    Args:
        config (TrainConfig): the training configuration. `config.model` and
            one of `config.data` or `config.dataflow` must be set.
        trainer (Trainer): an instance of :class:`SingleCostTrainer`.

    Example:

    .. code-block:: python

        launch_train_with_config(
            config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
    """
    if is_tfv2():
        tfv1.disable_eager_execution()

    assert isinstance(trainer, SingleCostTrainer), trainer
    assert isinstance(config, TrainConfig), config
    assert config.model is not None
    assert config.dataflow is not None or config.data is not None

    model = config.model
    input = config.data or config.dataflow
    input = apply_default_prefetch(input, trainer)

    # This is the only place where the `ModelDesc` abstraction is useful.
    # We should gradually move away from this unhelpful abstraction.
    # TowerFunc is a better abstraction (similar to tf.function in the future)
    trainer.setup_graph(
        model.get_input_signature(), input,
        model.build_graph, model.get_optimizer)
    _check_unused_regularization()
    trainer.train_with_defaults(
        callbacks=config.callbacks,
        monitors=config.monitors,
        session_creator=config.session_creator,
        session_init=config.session_init,
        steps_per_epoch=config.steps_per_epoch,
        starting_epoch=config.starting_epoch,
        max_epoch=config.max_epoch,
        extra_callbacks=config.extra_callbacks)


def _check_unused_regularization():
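The docstring of `launch_train_with_config` says its three steps can be done by hand when more control is needed. The sketch below illustrates only that call order and does not import tensorpack. `RecordingTrainer`, `launch_manually`, the `prefetch` argument, and the toy `config` and `model` objects are hypothetical stand-ins invented for this illustration; only the method names `setup_graph` and `train_with_defaults`, and the config attribute names, come from the function above.

```python
from types import SimpleNamespace


class RecordingTrainer:
    """Hypothetical stand-in for a trainer; it only records the calls it receives."""

    def __init__(self):
        self.calls = []

    def setup_graph(self, input_signature, input_source, build_graph, get_optimizer):
        self.calls.append(("setup_graph", input_signature, input_source))

    def train_with_defaults(self, **kwargs):
        self.calls.append(("train_with_defaults", kwargs))


def launch_manually(config, trainer, prefetch=lambda source, trainer: source):
    """Perform the three documented steps by hand (illustrative helper, not a library API)."""
    # Step 1: pick the input source. `prefetch` is a placeholder for the
    # library's prefetching heuristics and defaults to the identity here.
    source = prefetch(config.data or config.dataflow, trainer)
    # Step 2: build the graph from the model description.
    model = config.model
    trainer.setup_graph(
        model.get_input_signature(), source,
        model.build_graph, model.get_optimizer)
    # Step 3: forward the remaining config attributes to the training loop.
    trainer.train_with_defaults(
        callbacks=config.callbacks,
        steps_per_epoch=config.steps_per_epoch,
        max_epoch=config.max_epoch)
    return trainer


# Toy objects that expose only the attributes used above.
model = SimpleNamespace(
    get_input_signature=lambda: ["image", "label"],
    build_graph=lambda *inputs: None,
    get_optimizer=lambda: None)
config = SimpleNamespace(
    model=model, data=None, dataflow="toy-dataflow",
    callbacks=[], steps_per_epoch=10, max_epoch=2)

trainer = launch_manually(config, RecordingTrainer())
print([name for name, *_ in trainer.calls])
```

With a real trainer, doing the steps by hand lets you replace any one of them, for example by supplying your own input source in step 1 or passing a different set of arguments in step 3.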