(self)
| 204 | # training, and learns this waveform. |
| 205 | |
| 206 | def testEndToEndTraining(self): |
def CreateTrainingFeedDict(audio, speaker_ids, audio_placeholder,
                           gc_placeholder, i):
    """Build the feed dict for a single training step.

    Args:
        audio: Samples fed to ``audio_placeholder``.
        speaker_ids: Global-condition ids fed to ``gc_placeholder``, or
            None when the test runs without global conditioning.
        audio_placeholder: Key under which ``audio`` is fed.
        gc_placeholder: Key under which ``speaker_ids`` is fed; only used
            when ``speaker_ids`` is not None.
        i: Training-step index. Not read by this helper; kept because
            callers pass it positionally.

    Returns:
        A ``(feed_dict, speaker_index)`` tuple. ``speaker_index`` is
        always 0 here.
    """
    feeds = {audio_placeholder: audio}
    if speaker_ids is not None:
        # Global conditioning enabled: also feed the speaker ids.
        feeds[gc_placeholder] = speaker_ids
    return feeds, 0
| 217 | |
| 218 | np.random.seed(42) |
| 219 | audio, speaker_ids = make_sine_waves(self.global_conditioning) |
| 220 | # Pad with 0s (silence) times size of the receptive field minus one, |
| 221 | # because the first sample of the training data is 0 and if the network |
| 222 | # learns to predict silence based on silence, it will generate only |
| 223 | # silence. |
| 224 | if self.global_conditioning: |
| 225 | audio = np.pad(audio, ((0, 0), (self.net.receptive_field - 1, 0)), |
| 226 | 'constant') |
| 227 | else: |
| 228 | audio = np.pad(audio, (self.net.receptive_field - 1, 0), |
| 229 | 'constant') |
| 230 | |
| 231 | audio_placeholder = tf.placeholder(dtype=tf.float32) |
| 232 | gc_placeholder = tf.placeholder(dtype=tf.int32) \ |
| 233 | if self.global_conditioning else None |
| 234 | |
| 235 | loss = self.net.loss(input_batch=audio_placeholder, |
| 236 | global_condition_batch=gc_placeholder) |
| 237 | optimizer = optimizer_factory[self.optimizer_type]( |
| 238 | learning_rate=self.learning_rate, momentum=self.momentum) |
| 239 | trainable = tf.trainable_variables() |
| 240 | optim = optimizer.minimize(loss, var_list=trainable) |
| 241 | init = tf.global_variables_initializer() |
| 242 | |
| 243 | generated_waveform = None |
| 244 | max_allowed_loss = 0.1 |
| 245 | loss_val = max_allowed_loss |
| 246 | initial_loss = None |
| 247 | operations = [loss, optim] |
| 248 | with self.test_session() as sess: |
| 249 | feed_dict, speaker_index = CreateTrainingFeedDict( |
| 250 | audio, speaker_ids, audio_placeholder, gc_placeholder, 0) |
| 251 | sess.run(init) |
| 252 | initial_loss = sess.run(loss, feed_dict=feed_dict) |
| 253 | for i in range(self.train_iters): |
| 254 | feed_dict, speaker_index = CreateTrainingFeedDict( |
| 255 | audio, speaker_ids, audio_placeholder, gc_placeholder, i) |
| 256 | [results] = sess.run([operations], feed_dict=feed_dict) |
| 257 | if i % 100 == 0: |
| 258 | print("i: %d loss: %f" % (i, results[0])) |
| 259 | |
| 260 | loss_val = results[0] |
| 261 | |
| 262 | # Sanity check the initial loss was larger. |
| 263 | self.assertGreater(initial_loss, max_allowed_loss) |
nothing calls this directly
no test coverage detected