()
| 223 | |
| 224 | |
def train():
    """Split the GPUs between training and inference, then run training.

    Steps, in order:

    1. Query the GPU count and derive ``train_tower`` / ``predict_tower``.
    2. Rebind the module-level ``PREDICTOR_THREAD`` to match the number of
       inference GPUs (or to 1 when there is no GPU).
    3. Spawn ``SIMULATOR_PROC`` simulator worker processes that share a pair
       of uniquely named ``ipc://`` endpoints with the simulator master.
    4. Assemble a ``TrainConfig`` and hand it to ``launch_train_with_config``.

    Reads ``args.load`` from the enclosing module scope for session
    initialization.
    """
    global PREDICTOR_THREAD

    gpu_count = get_num_gpu()
    if gpu_count <= 0:
        logger.warn("Without GPU this model will never learn! CPU is only useful for debug.")
        PREDICTOR_THREAD = 1
        predict_tower = [0]
        train_tower = [0]
    else:
        gpu_ids = list(range(gpu_count))
        # Unary minus binds tighter than ``//``, so this is
        # floor(-gpu_count / 2): the *last* ceil(gpu_count / 2) GPUs go to
        # inference and the remaining leading ones go to training.
        split = -gpu_count // 2
        predict_tower = gpu_ids[split:] if gpu_count > 1 else [0]
        PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
        # With a single GPU the training slice is empty; fall back to GPU 0
        # so training and inference share it.
        train_tower = gpu_ids[:split] or [0]
        train_desc = ','.join(str(gpu) for gpu in train_tower)
        predict_desc = ','.join(str(gpu) for gpu in predict_tower)
        logger.info("[Batch-A3C] Train on gpu {} and infer on gpu {}".format(
            train_desc, predict_desc))

    # Simulator processes: a short uuid-derived tag keeps the two pipe names
    # unique per run. Linux endpoints additionally get an '@' prefix.
    run_tag = str(uuid.uuid1())[:6]
    if sys.platform.startswith('linux'):
        pipe_prefix = '@'
    else:
        pipe_prefix = ''
    namec2s = 'ipc://{}sim-c2s-{}'.format(pipe_prefix, run_tag)
    names2c = 'ipc://{}sim-s2c-{}'.format(pipe_prefix, run_tag)

    workers = []
    for worker_idx in range(SIMULATOR_PROC):
        workers.append(MySimulatorWorker(worker_idx, namec2s, names2c))
    ensure_proc_terminate(workers)
    start_proc_mask_signal(workers)

    master = MySimulatorMaster(namec2s, names2c, predict_tower)

    # Build the TrainConfig pieces one by one, in the same order the keyword
    # arguments were originally evaluated.
    model = Model()
    dataflow = master.get_training_dataflow()
    evaluator = Evaluator(EVAL_EPISODE, ['state'], ['policy'], get_player)
    callbacks = [
        ModelSaver(),
        ScheduledHyperParamSetter('learning_rate', [(20, 0.0003), (120, 0.0001)]),
        ScheduledHyperParamSetter('entropy_beta', [(80, 0.005)]),
        # The master is registered as a callback in addition to providing
        # the training dataflow above.
        master,
        PeriodicTrigger(evaluator, every_k_epochs=3),
    ]
    # NOTE(review): 0.5 is presumably a per-process GPU memory fraction -- confirm
    # against get_default_sess_config.
    session_creator = sesscreate.NewSessionCreator(config=get_default_sess_config(0.5))
    session_init = SmartInit(args.load)
    config = TrainConfig(
        model=model,
        dataflow=dataflow,
        callbacks=callbacks,
        session_creator=session_creator,
        steps_per_epoch=STEPS_PER_EPOCH,
        session_init=session_init,
        max_epoch=1000,
    )

    # Only an exact GPU count of 1 selects the simple trainer; every other
    # count (including 0) uses the async multi-GPU trainer on train_tower.
    if gpu_count == 1:
        trainer = SimpleTrainer()
    else:
        trainer = AsyncMultiGPUTrainer(train_tower)
    launch_train_with_config(config, trainer)
| 273 | |
| 274 | |
| 275 | if __name__ == '__main__': |
no test coverage detected