MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / train

Function train

examples/A3C-Gym/train-atari.py:225–272  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

223
224
def train():
    """Launch Batch-A3C training.

    Splits the available GPUs between training towers and inference
    (predictor) towers, starts ``SIMULATOR_PROC`` simulator worker processes
    that talk to a ``MySimulatorMaster`` over two IPC endpoints, and runs a
    tensorpack training loop built from ``TrainConfig``.

    Side effects:
        Rebinds the module-level ``PREDICTOR_THREAD`` to the number of
        predictor threads to use: ``PREDICTOR_THREAD_PER_GPU`` per inference
        GPU, or 1 when running without a GPU.
    """
    global PREDICTOR_THREAD

    # Assign GPUs for training & inference.
    num_gpu = get_num_gpu()
    if num_gpu > 0:
        all_gpus = list(range(num_gpu))
        if num_gpu > 1:
            # Use (roughly) half of the GPUs for inference. ``-num_gpu // 2``
            # parses as ``(-num_gpu) // 2``, so inference gets the last
            # ceil(num_gpu / 2) GPUs and training keeps the remaining ones.
            predict_tower = all_gpus[-num_gpu // 2:]
        else:
            # A single GPU is shared by inference and training.
            predict_tower = [0]
        PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
        # With exactly one GPU this slice is empty, so fall back to GPU 0.
        train_tower = all_gpus[:-num_gpu // 2] or [0]
        logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
            ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
    else:
        logger.warning("Without GPU this model will never learn! CPU is only useful for debug.")
        PREDICTOR_THREAD = 1
        predict_tower, train_tower = [0], [0]

    # Set up simulator processes. A short uuid1-derived suffix keeps the IPC
    # endpoint names distinct between runs. NOTE(review): the '@' prefix on
    # Linux presumably selects ZeroMQ's abstract socket namespace - confirm.
    name_base = str(uuid.uuid1())[:6]
    prefix = '@' if sys.platform.startswith('linux') else ''
    namec2s = 'ipc://{}sim-c2s-{}'.format(prefix, name_base)
    names2c = 'ipc://{}sim-s2c-{}'.format(prefix, name_base)
    procs = [MySimulatorWorker(k, namec2s, names2c) for k in range(SIMULATOR_PROC)]
    ensure_proc_terminate(procs)
    start_proc_mask_signal(procs)

    # The master both feeds training data (its dataflow) and runs as a
    # callback; it serves predictions on the inference towers chosen above.
    master = MySimulatorMaster(namec2s, names2c, predict_tower)
    config = TrainConfig(
        model=Model(),
        dataflow=master.get_training_dataflow(),
        callbacks=[
            ModelSaver(),
            # Step schedules keyed by epoch: lr -> 3e-4 at epoch 20 and
            # 1e-4 at epoch 120; entropy_beta -> 0.005 at epoch 80.
            ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
            ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
            master,
            # Run EVAL_EPISODE evaluation episodes every 3 epochs.
            PeriodicTrigger(Evaluator(
                EVAL_EPISODE, ['state'], ['policy'], get_player),
                every_k_epochs=3),
        ],
        # NOTE(review): 0.5 is presumably tensorpack's GPU mem_fraction,
        # leaving room for the inference predictors - confirm.
        session_creator=sesscreate.NewSessionCreator(config=get_default_sess_config(0.5)),
        steps_per_epoch=STEPS_PER_EPOCH,
        session_init=SmartInit(args.load),
        max_epoch=1000,
    )
    # One training tower (a single GPU, or a CPU-only debug run) uses the
    # single-tower SimpleTrainer. AsyncMultiGPUTrainer takes GPU ids, so it
    # is used only when GPUs actually exist.
    trainer = SimpleTrainer() if num_gpu <= 1 else AsyncMultiGPUTrainer(train_tower)
    launch_train_with_config(config, trainer)
273
274
275if __name__ == '__main__':

Callers 1

train-atari.py · File · 0.85

Calls 15

get_training_dataflow · Method · 0.95
get_num_gpu · Function · 0.90
ensure_proc_terminate · Function · 0.90
start_proc_mask_signal · Function · 0.90
Evaluator · Class · 0.90
MySimulatorWorker · Class · 0.85
MySimulatorMaster · Class · 0.85
TrainConfig · Class · 0.85
ModelSaver · Class · 0.85
PeriodicTrigger · Class · 0.85
get_default_sess_config · Function · 0.85

Tested by

no test coverage detected