(args, logdir1, logdir2)
| 23 | |
| 24 | |
def train(args, logdir1, logdir2):
    """Build the Net2 training config and launch tensorpack training.

    Args:
        args: parsed command-line namespace. Reads ``args.ckpt`` (a checkpoint
            name inside ``logdir2``; falsy means "use the latest checkpoint in
            ``logdir2``") and ``args.gpu`` (comma-separated device ids such as
            ``"0,1"``; falsy means "use ``hp.train2.num_gpu`` GPUs").
        logdir1: directory whose latest checkpoint, if any, is restored with
            ``global_step`` skipped (presumably the trained Net1 — confirm).
        logdir2: directory for this run: the logger writes here, ModelSaver
            saves checkpoints here, and training resumes from here.
    """
    # Restrict visible devices before anything below touches TensorFlow, and
    # derive the tower count from the same list so the trainer never asks for
    # more GPUs than are visible. Empty entries (e.g. "0,1,") are ignored.
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        num_gpu = len([g for g in args.gpu.split(',') if g.strip()])
    else:
        num_gpu = hp.train2.num_gpu

    # model
    model = Net2()

    # dataflow
    df = Net2DataFlow(hp.train2.data_path, hp.train2.batch_size)

    # set logger for event and model saver
    logger.set_logger_dir(logdir2)

    # Session initializers are chained in list order: first this run's own
    # checkpoint (explicit name or latest), then logdir1's latest checkpoint.
    # global_step is skipped for the logdir1 restore so it does not overwrite
    # the step counter restored from logdir2 (or the freshly created one).
    session_inits = []
    if args.ckpt:
        ckpt2 = os.path.join(logdir2, args.ckpt)
    else:
        ckpt2 = tf.train.latest_checkpoint(logdir2)
    if ckpt2:
        session_inits.append(SaverRestore(ckpt2))
    ckpt1 = tf.train.latest_checkpoint(logdir1)
    if ckpt1:
        session_inits.append(SaverRestore(ckpt1, ignore=['global_step']))

    train_conf = TrainConfig(
        model=model,
        data=QueueInput(df(n_prefetch=1000, n_thread=4)),
        callbacks=[
            # TODO save on prefix net2
            ModelSaver(checkpoint_dir=logdir2),
            # ConvertCallback(logdir2, hp.train2.test_per_epoch),
        ],
        max_epoch=hp.train2.num_epochs,
        steps_per_epoch=hp.train2.steps_per_epoch,
        session_init=ChainInit(session_inits),
    )
    if args.gpu:
        # NOTE(review): presumably a legacy attribute; the trainer below is
        # what sizes the towers. Kept in sync for compatibility.
        train_conf.nr_tower = num_gpu

    # Use the same GPU count that was made visible above, instead of always
    # using hp.train2.num_gpu.
    trainer = SyncMultiGPUTrainerReplicated(num_gpu)

    launch_train_with_config(train_conf, trainer=trainer)
| 68 | |
| 69 | |
| 70 | # def get_cyclic_lr(step): |
no test coverage detected