github.com/andabi/deep-voice-conversion

Function train

train2.py:25–67
Signature: train(args, logdir1, logdir2)

```python
def train(args, logdir1, logdir2):
    # model
    model = Net2()

    # dataflow
    df = Net2DataFlow(hp.train2.data_path, hp.train2.batch_size)

    # set logger for event and model saver
    logger.set_logger_dir(logdir2)

    # session_conf = tf.ConfigProto(
    #     gpu_options=tf.GPUOptions(
    #         allow_growth=True,
    #         per_process_gpu_memory_fraction=0.6,
    #     ),
    # )

    session_inits = []
    ckpt2 = '{}/{}'.format(logdir2, args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir2)
    if ckpt2:
        session_inits.append(SaverRestore(ckpt2))
    ckpt1 = tf.train.latest_checkpoint(logdir1)
    if ckpt1:
        session_inits.append(SaverRestore(ckpt1, ignore=['global_step']))
    train_conf = TrainConfig(
        model=model,
        data=QueueInput(df(n_prefetch=1000, n_thread=4)),
        callbacks=[
            # TODO save on prefix net2
            ModelSaver(checkpoint_dir=logdir2),
            # ConvertCallback(logdir2, hp.train2.test_per_epoch),
        ],
        max_epoch=hp.train2.num_epochs,
        steps_per_epoch=hp.train2.steps_per_epoch,
        session_init=ChainInit(session_inits)
    )
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        train_conf.nr_tower = len(args.gpu.split(','))

    trainer = SyncMultiGPUTrainerReplicated(hp.train2.num_gpu)

    launch_train_with_config(train_conf, trainer=trainer)


# def get_cyclic_lr(step):
```
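
The two setup decisions in the listing that do not need TensorFlow (which checkpoints get restored, in which order, and how `--gpu` becomes a tower count) can be sketched in plain Python. This is a minimal sketch under stated assumptions: `resolve_restore_chain`, `apply_gpu_setting` and the `(path, ignore_list)` tuples are hypothetical names standing in for the `SaverRestore` objects, and `latest_checkpoint` is passed in as a function so the sketch runs without `tf.train.latest_checkpoint`.

```python
from typing import Callable, List, MutableMapping, Optional, Tuple

# Hypothetical stand-in for a SaverRestore: (checkpoint path, variables to skip).
RestoreSpec = Tuple[str, List[str]]


def resolve_restore_chain(
    logdir1: str,
    logdir2: str,
    ckpt: Optional[str],
    latest_checkpoint: Callable[[str], Optional[str]],
) -> List[RestoreSpec]:
    """Return restore specs in the same order train() appends them."""
    specs: List[RestoreSpec] = []
    # An explicit --ckpt name is joined onto logdir2; otherwise use the
    # newest checkpoint in logdir2, if there is one (resuming Net2).
    ckpt2 = '{}/{}'.format(logdir2, ckpt) if ckpt else latest_checkpoint(logdir2)
    if ckpt2:
        specs.append((ckpt2, []))
    # The newest Net1 checkpoint is restored next, skipping global_step so
    # Net1's step counter is not loaded into the Net2 run.
    ckpt1 = latest_checkpoint(logdir1)
    if ckpt1:
        specs.append((ckpt1, ['global_step']))
    return specs


def apply_gpu_setting(gpu: Optional[str], environ: MutableMapping[str, str]) -> Optional[int]:
    """Export the --gpu device list and return the tower count, as train() does."""
    if not gpu:
        return None  # train() leaves nr_tower untouched in this case
    environ['CUDA_VISIBLE_DEVICES'] = gpu
    return len(gpu.split(','))


if __name__ == '__main__':
    # Fake checkpoint lookup: Net1 has been trained, Net2 has not started yet.
    found = {'logs/net1': 'logs/net1/model-1000', 'logs/net2': None}
    print(resolve_restore_chain('logs/net1', 'logs/net2', None, found.get))
    # [('logs/net1/model-1000', ['global_step'])]
    env = {}
    print(apply_gpu_setting('0,1', env), env)
    # 2 {'CUDA_VISIBLE_DEVICES': '0,1'}
```

One detail the sketch makes visible: the listing stores the tower count on `train_conf.nr_tower`, but builds the trainer with `SyncMultiGPUTrainerReplicated(hp.train2.num_gpu)`, so `hp.train2.num_gpu` presumably needs to agree with the number of devices passed via `--gpu`.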

Callers (1)

train2.py (file) · 0.70

Calls (2)

Net2 (class) · 0.90
Net2DataFlow (class) · 0.90

Tested by

No test coverage detected.