Creates a WaveNet network and returns the autoencoding loss. The variables are all scoped to the given name.
(self,
input_batch,
global_condition_batch=None,
l2_regularization_strength=None,
name='wavenet')
| 615 | return tf.reshape(last, [-1]) |
| 616 | |
| 617 | def loss(self, |
| 618 | input_batch, |
| 619 | global_condition_batch=None, |
| 620 | l2_regularization_strength=None, |
| 621 | name='wavenet'): |
| 622 | '''Creates a WaveNet network and returns the autoencoding loss. |
| 623 | |
| 624 | The variables are all scoped to the given name. |
| 625 | ''' |
| 626 | with tf.name_scope(name): |
| 627 | # We mu-law encode and quantize the input audio waveform. |
| 628 | encoded_input = mu_law_encode(input_batch, |
| 629 | self.quantization_channels) |
| 630 | |
| 631 | gc_embedding = self._embed_gc(global_condition_batch) |
| 632 | encoded = self._one_hot(encoded_input) |
| 633 | if self.scalar_input: |
| 634 | network_input = tf.reshape( |
| 635 | tf.cast(input_batch, tf.float32), |
| 636 | [self.batch_size, -1, 1]) |
| 637 | else: |
| 638 | network_input = encoded |
| 639 | |
| 640 | # Cut off the last sample of network input to preserve causality. |
| 641 | network_input_width = tf.shape(network_input)[1] - 1 |
| 642 | network_input = tf.slice(network_input, [0, 0, 0], |
| 643 | [-1, network_input_width, -1]) |
| 644 | |
| 645 | raw_output = self._create_network(network_input, gc_embedding) |
| 646 | |
| 647 | with tf.name_scope('loss'): |
| 648 | # Cut off the samples corresponding to the receptive field |
| 649 | # for the first predicted sample. |
| 650 | target_output = tf.slice( |
| 651 | tf.reshape( |
| 652 | encoded, |
| 653 | [self.batch_size, -1, self.quantization_channels]), |
| 654 | [0, self.receptive_field, 0], |
| 655 | [-1, -1, -1]) |
| 656 | target_output = tf.reshape(target_output, |
| 657 | [-1, self.quantization_channels]) |
| 658 | prediction = tf.reshape(raw_output, |
| 659 | [-1, self.quantization_channels]) |
| 660 | loss = tf.nn.softmax_cross_entropy_with_logits( |
| 661 | logits=prediction, |
| 662 | labels=target_output) |
| 663 | reduced_loss = tf.reduce_mean(loss) |
| 664 | |
| 665 | tf.summary.scalar('loss', reduced_loss) |
| 666 | |
| 667 | if l2_regularization_strength is None: |
| 668 | return reduced_loss |
| 669 | else: |
| 670 | # L2 regularization for all trainable parameters |
| 671 | l2_loss = tf.add_n([tf.nn.l2_loss(v) |
| 672 | for v in tf.trainable_variables() |
| 673 | if not('bias' in v.name)]) |
| 674 |