
Function launch_train_with_config

tensorpack/train/interface.py, lines 46–99 (repository: github.com/tensorpack/tensorpack)


(config, trainer)

Source:

def launch_train_with_config(config, trainer):
    """
    Train with a :class:`TrainConfig` and a :class:`Trainer`, to
    present the simple and old training interface. It basically does the following
    3 things (and you can easily do them by yourself if you need more control):

    1. Setup the input with automatic prefetching heuristics,
       from `config.data` or `config.dataflow`.
    2. Call `trainer.setup_graph` with the input as well as `config.model`.
    3. Call `trainer.train` with rest of the attributes of config.

    See the `related tutorial
    <https://tensorpack.readthedocs.io/tutorial/training-interface.html#with-modeldesc-and-trainconfig>`_
    to learn more.

    Args:
        config (TrainConfig):
        trainer (Trainer): an instance of :class:`SingleCostTrainer`.

    Example:

    .. code-block:: python

        launch_train_with_config(
            config, SyncMultiGPUTrainerParameterServer(8, ps_device='gpu'))
    """
    if is_tfv2():
        tfv1.disable_eager_execution()

    assert isinstance(trainer, SingleCostTrainer), trainer
    assert isinstance(config, TrainConfig), config
    assert config.model is not None
    assert config.dataflow is not None or config.data is not None

    model = config.model
    input = config.data or config.dataflow
    input = apply_default_prefetch(input, trainer)

    # This is the only place where the `ModelDesc` abstraction is useful.
    # We should gradually stay away from this unuseful abstraction.
    # TowerFunc is a better abstraction (similar to tf.function in the future)
    trainer.setup_graph(
        model.get_input_signature(), input,
        model.build_graph, model.get_optimizer)
    _check_unused_regularization()
    trainer.train_with_defaults(
        callbacks=config.callbacks,
        monitors=config.monitors,
        session_creator=config.session_creator,
        session_init=config.session_init,
        steps_per_epoch=config.steps_per_epoch,
        starting_epoch=config.starting_epoch,
        max_epoch=config.max_epoch,
        extra_callbacks=config.extra_callbacks)
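Before touching the graph, the function validates its arguments with `assert` statements and then picks its input with `config.data or config.dataflow`. The sketch below mirrors only those two steps in plain Python, so the accepted and rejected cases can be tried without TensorFlow. `StubSingleCostTrainer`, `StubMultiGPUTrainer`, `StubOtherTrainer`, and `check_and_pick_input` are hypothetical names, not tensorpack classes; the config is a `SimpleNamespace` standing in for `TrainConfig`, so the `isinstance(config, TrainConfig)` check and the prefetching step are deliberately left out.

```python
from types import SimpleNamespace


class StubSingleCostTrainer:
    """Hypothetical stand-in for tensorpack's SingleCostTrainer."""


class StubMultiGPUTrainer(StubSingleCostTrainer):
    """Hypothetical subclass, playing the role of a multi-GPU trainer."""


class StubOtherTrainer:
    """Hypothetical trainer that does NOT derive from SingleCostTrainer."""


def check_and_pick_input(config, trainer):
    # Same checks as launch_train_with_config, minus the TrainConfig type check.
    assert isinstance(trainer, StubSingleCostTrainer), trainer
    assert config.model is not None
    assert config.dataflow is not None or config.data is not None
    # `or` returns its first truthy operand: `data` wins whenever it is set.
    return config.data or config.dataflow


both = SimpleNamespace(model="model", data="input-source", dataflow="dataflow")
dataflow_only = SimpleNamespace(model="model", data=None, dataflow="dataflow")
no_input = SimpleNamespace(model="model", data=None, dataflow=None)

print(check_and_pick_input(both, StubMultiGPUTrainer()))             # input-source
print(check_and_pick_input(dataflow_only, StubSingleCostTrainer()))  # dataflow

# Both of these violate an assertion and are rejected.
for cfg, tr in [(no_input, StubSingleCostTrainer()),
                (dataflow_only, StubOtherTrainer())]:
    try:
        check_and_pick_input(cfg, tr)
    except AssertionError:
        print("rejected")
```

Subclasses pass the `isinstance` check, which is why a multi-GPU trainer such as the docstring's `SyncMultiGPUTrainerParameterServer` can be passed in (presumably it derives from `SingleCostTrainer`). Because the checks are plain `assert` statements, they are skipped when Python runs with `-O`, so an invalid config would then only be caught, if at all, further down the call chain.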

Callers (15):

imagenet-resnet.py (file, 0.90)
boilerplate.py (file, 0.85)
hed.py (file, 0.85)
cifar10-resnet.py (file, 0.85)
train (function, 0.85)
char-rnn.py (file, 0.85)
PTB-LSTM.py (file, 0.85)
CAM-resnet.py (file, 0.85)
steering-filter.py (file, 0.85)
train-timit.py (file, 0.85)
DQN.py (file, 0.85)

Calls (6):

is_tfv2 (function, 0.85)
apply_default_prefetch (function, 0.85)
get_input_signature (method, 0.80)
train_with_defaults (method, 0.80)
setup_graph (method, 0.45)
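The calls listed above run in a fixed order: `setup_graph` first, then `train_with_defaults` with eight keyword arguments copied from the config. Two details are easy to miss. First, `model.get_input_signature()` is evaluated immediately, while `model.build_graph` and `model.get_optimizer` are passed as uncalled bound methods, leaving it to the trainer to call them. Second, although step 3 of the docstring mentions `trainer.train`, the body actually calls `trainer.train_with_defaults`. The sketch below reproduces that order with hypothetical stubs (`RecordingTrainer`, `StubModel`, `launch_sketch` are illustrative names, not tensorpack APIs) and omits the TensorFlow 2 check, prefetching, and `_check_unused_regularization()`.

```python
from types import SimpleNamespace

# Keyword arguments forwarded to train_with_defaults, as in the source above.
TRAIN_KEYS = ("callbacks", "monitors", "session_creator", "session_init",
              "steps_per_epoch", "starting_epoch", "max_epoch", "extra_callbacks")


class RecordingTrainer:
    """Hypothetical trainer stub that only records what it is asked to do."""

    def __init__(self):
        self.calls = []

    def setup_graph(self, input_signature, input_source, get_cost_fn, get_opt_fn):
        # A real trainer would build the graph here; the stub records that the
        # cost and optimizer functions arrived as callables, not results.
        self.calls.append(("setup_graph", input_signature, input_source,
                           callable(get_cost_fn), callable(get_opt_fn)))

    def train_with_defaults(self, **kwargs):
        self.calls.append(("train_with_defaults", sorted(kwargs)))


class StubModel:
    """Hypothetical stand-in for a ModelDesc."""

    def get_input_signature(self):
        return ["image", "label"]

    def build_graph(self, image, label):
        return "cost"

    def get_optimizer(self):
        return "optimizer"


def launch_sketch(config, trainer):
    """Mirror the call order of launch_train_with_config (no TF, no prefetch)."""
    input_source = config.data or config.dataflow
    model = config.model
    trainer.setup_graph(model.get_input_signature(), input_source,
                        model.build_graph, model.get_optimizer)
    trainer.train_with_defaults(**{k: getattr(config, k) for k in TRAIN_KEYS})


config = SimpleNamespace(model=StubModel(), data=None, dataflow="dataflow",
                         **{k: None for k in TRAIN_KEYS})
trainer = RecordingTrainer()
launch_sketch(config, trainer)
for call in trainer.calls:
    print(call)
```

Passing the bound methods rather than their results is what lets the trainer decide when, and in which device context, the model graph and optimizer are actually built.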

Tested by

no test coverage detected