MCPcopy
hub / github.com/ibab/tensorflow-wavenet / testEndToEndTraining

Method testEndToEndTraining

test/test_model.py:206–298  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

204 # training, and learns this waveform.
205
206 def testEndToEndTraining(self):
def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                           gc_placeholder, i):
    """Build the feed dict for one training step.

    Args:
        audio: Value fed to `audio_placeholder`.
        speaker_ids: Value fed to `gc_placeholder`, or None when global
            conditioning is not in use.
        audio_placeholder: Feed-dict key that receives `audio`.
        gc_placeholder: Feed-dict key that receives `speaker_ids`; only
            used when `speaker_ids` is not None.
        i: Training iteration index. Accepted for the callers' benefit
            but not used here.

    Returns:
        A `(feed_dict, speaker_index)` tuple. `speaker_index` is always 0.
    """
    # The audio entry is always present; the conditioning entry is added
    # only when speaker ids were supplied.
    feeds = {audio_placeholder: audio}
    if speaker_ids is not None:
        feeds[gc_placeholder] = speaker_ids
    return feeds, 0
217
218 np.random.seed(42)
219 audio, speaker_ids = make_sine_waves(self.global_conditioning)
220 # Left-pad the audio with (receptive field size - 1) zeros (silence),
221 # because the first sample of the training data is 0 and, if the network
222 # learns to predict silence based on silence, it will generate only
223 # silence.
224 if self.global_conditioning:
225 audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)),
226 'constant')
227 else:
228 audio = np.pad(audio, (self.net.receptive_field - 1, 0),
229 'constant')
230
231 audio_placeholder = tf.placeholder(dtype=tf.float32)
232 gc_placeholder = tf.placeholder(dtype=tf.int32) \
233 if self.global_conditioning else None
234
235 loss = self.net.loss(input_batch=audio_placeholder,
236 global_condition_batch=gc_placeholder)
237 optimizer = optimizer_factory[self.optimizer_type](
238 learning_rate=self.learning_rate, momentum=self.momentum)
239 trainable = tf.trainable_variables()
240 optim = optimizer.minimize(loss, var_list=trainable)
241 init = tf.global_variables_initializer()
242
243 generated_waveform = None
244 max_allowed_loss = 0.1
245 loss_val = max_allowed_loss
246 initial_loss = None
247 operations = [loss, optim]
248 with self.test_session() as sess:
249 feed_dict, speaker_index = CreateTrainingFeedDict(
250 audio, speaker_ids, audio_placeholder, gc_placeholder, 0)
251 sess.run(init)
252 initial_loss = sess.run(loss, feed_dict=feed_dict)
253 for i in range(self.train_iters):
254 feed_dict, speaker_index = CreateTrainingFeedDict(
255 audio, speaker_ids, audio_placeholder, gc_placeholder, i)
256 [results] = sess.run([operations], feed_dict=feed_dict)
257 if i % 100 == 0:
258 print("i: %d loss: %f" % (i, results[0]))
259
260 loss_val = results[0]
261
262 # Sanity check the initial loss was larger.
263 self.assertGreater(initial_loss, max_allowed_loss)

Callers

nothing calls this directly

Calls 4

make_sine_wavesFunction · 0.85
generate_waveformsFunction · 0.85
check_waveformFunction · 0.85
lossMethod · 0.80

Tested by

no test coverage detected