MCPcopy
hub / github.com/ibab/tensorflow-wavenet / main

Function main

train.py:188–333  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

186
187
188def main():
189 args = get_arguments()
190
191 try:
192 directories = validate_directories(args)
193 except ValueError as e:
194 print("Some arguments are wrong:")
195 print(str(e))
196 return
197
198 logdir = directories['logdir']
199 restore_from = directories['restore_from']
200
201 # Even if we restored the model, we will treat it as new training
202 # if the trained model is written into an arbitrary location.
203 is_overwritten_training = logdir != restore_from
204
205 with open(args.wavenet_params, 'r') as f:
206 wavenet_params = json.load(f)
207
208 # Create coordinator.
209 coord = tf.train.Coordinator()
210
211 # Load raw waveform from VCTK corpus.
212 with tf.name_scope('create_inputs'):
213 # Allow silence trimming to be skipped by specifying a threshold near
214 # zero.
215 silence_threshold = args.silence_threshold if args.silence_threshold > \
216 EPSILON else None
217 gc_enabled = args.gc_channels is not None
218 reader = AudioReader(
219 args.data_dir,
220 coord,
221 sample_rate=wavenet_params['sample_rate'],
222 gc_enabled=gc_enabled,
223 receptive_field=WaveNetModel.calculate_receptive_field(wavenet_params["filter_width"],
224 wavenet_params["dilations"],
225 wavenet_params["scalar_input"],
226 wavenet_params["initial_filter_width"]),
227 sample_size=args.sample_size,
228 silence_threshold=silence_threshold)
229 audio_batch = reader.dequeue(args.batch_size)
230 if gc_enabled:
231 gc_id_batch = reader.dequeue_gc(args.batch_size)
232 else:
233 gc_id_batch = None
234
235 # Create network.
236 net = WaveNetModel(
237 batch_size=args.batch_size,
238 dilations=wavenet_params["dilations"],
239 filter_width=wavenet_params["filter_width"],
240 residual_channels=wavenet_params["residual_channels"],
241 dilation_channels=wavenet_params["dilation_channels"],
242 skip_channels=wavenet_params["skip_channels"],
243 quantization_channels=wavenet_params["quantization_channels"],
244 use_biases=wavenet_params["use_biases"],
245 scalar_input=wavenet_params["scalar_input"],

Callers 1

train.py File · 0.70

Calls 11

dequeue Method · 0.95
dequeue_gc Method · 0.95
loss Method · 0.95
start_threads Method · 0.95
AudioReader Class · 0.90
WaveNetModel Class · 0.90
validate_directories Function · 0.85
load Function · 0.85
save Function · 0.85
get_arguments Function · 0.70

Tested by

no test coverage detected