MCPcopy
hub / github.com/ibab/tensorflow-wavenet / loss

Method loss

wavenet/model.py:617–682  ·  view source on GitHub ↗

Creates a WaveNet network and returns the autoencoding loss. The variables are all scoped to the given name.

(self,
             input_batch,
             global_condition_batch=None,
             l2_regularization_strength=None,
             name='wavenet')

Source from the content-addressed store, hash-verified

615 return tf.reshape(last, [-1])
616
617 def loss(self,
618 input_batch,
619 global_condition_batch=None,
620 l2_regularization_strength=None,
621 name='wavenet'):
622 '''Creates a WaveNet network and returns the autoencoding loss.
623
624 The variables are all scoped to the given name.
625 '''
626 with tf.name_scope(name):
627 # We mu-law encode and quantize the input audioform.
628 encoded_input = mu_law_encode(input_batch,
629 self.quantization_channels)
630
631 gc_embedding = self._embed_gc(global_condition_batch)
632 encoded = self._one_hot(encoded_input)
633 if self.scalar_input:
634 network_input = tf.reshape(
635 tf.cast(input_batch, tf.float32),
636 [self.batch_size, -1, 1])
637 else:
638 network_input = encoded
639
640 # Cut off the last sample of network input to preserve causality.
641 network_input_width = tf.shape(network_input)[1] - 1
642 network_input = tf.slice(network_input, [0, 0, 0],
643 [-1, network_input_width, -1])
644
645 raw_output = self._create_network(network_input, gc_embedding)
646
647 with tf.name_scope('loss'):
648 # Cut off the samples corresponding to the receptive field
649 # for the first predicted sample.
650 target_output = tf.slice(
651 tf.reshape(
652 encoded,
653 [self.batch_size, -1, self.quantization_channels]),
654 [0, self.receptive_field, 0],
655 [-1, -1, -1])
656 target_output = tf.reshape(target_output,
657 [-1, self.quantization_channels])
658 prediction = tf.reshape(raw_output,
659 [-1, self.quantization_channels])
660 loss = tf.nn.softmax_cross_entropy_with_logits(
661 logits=prediction,
662 labels=target_output)
663 reduced_loss = tf.reduce_mean(loss)
664
665 tf.summary.scalar('loss', reduced_loss)
666
667 if l2_regularization_strength is None:
668 return reduced_loss
669 else:
670 # L2 regularization for all trainable parameters
671 l2_loss = tf.add_n([tf.nn.l2_loss(v)
672 for v in tf.trainable_variables()
673 if not('bias' in v.name)])
674

Callers 2

main (function) · 0.95
testEndToEndTraining (method) · 0.80

Calls 4

_embed_gc (method) · 0.95
_one_hot (method) · 0.95
_create_network (method) · 0.95
mu_law_encode (function) · 0.85

Tested by 1

testEndToEndTraining (method) · 0.64