MCPcopy
hub / github.com/ibab/tensorflow-wavenet / _create_network

Method _create_network

wavenet/model.py:395–442  ·  view source on GitHub ↗

Construct the WaveNet network.

(self, input_batch, global_condition_batch)

Source from the content-addressed store, hash-verified

393 return skip_contribution, input_batch + transformed
394
def _create_network(self, input_batch, global_condition_batch):
    '''Construct the WaveNet network.

    The input goes through the initial causal convolution and then through
    one dilation layer per entry in ``self.dilations``. Each layer yields a
    skip output and the input for the next layer. The skip outputs are
    summed and post-processed with ReLU -> 1x1 conv -> ReLU -> 1x1 conv,
    adding biases after each convolution when ``self.use_biases`` is set.

    Args:
        input_batch: Network input tensor. Dimension 1 is used together
            with ``self.receptive_field`` to compute ``output_width``, so it
            is presumably the time axis -- assumed (batch, time, channels),
            TODO confirm.
        global_condition_batch: Global conditioning tensor. It is forwarded
            unchanged to every ``_create_dilation_layer`` call (whether it
            may be None is decided there -- TODO confirm).

    Returns:
        The output tensor of the second post-processing 1x1 convolution.
    '''
    outputs = []

    # Pre-process the input with a regular convolution.
    current_layer = self._create_causal_layer(input_batch)

    # Width passed to every dilation layer. Presumably this is the number
    # of time steps that see a full receptive field -- TODO confirm.
    output_width = tf.shape(input_batch)[1] - self.receptive_field + 1

    # Add all defined dilation layers.
    with tf.name_scope('dilated_stack'):
        for layer_index, dilation in enumerate(self.dilations):
            with tf.name_scope('layer{}'.format(layer_index)):
                output, current_layer = self._create_dilation_layer(
                    current_layer, layer_index, dilation,
                    global_condition_batch, output_width)
                outputs.append(output)

    with tf.name_scope('postprocessing'):
        # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
        # postprocess the output.
        w1 = self.variables['postprocessing']['postprocess1']
        w2 = self.variables['postprocessing']['postprocess2']
        if self.use_biases:
            b1 = self.variables['postprocessing']['postprocess1_bias']
            b2 = self.variables['postprocessing']['postprocess2_bias']

        if self.histograms:
            # tf.histogram_summary is not available at the top level from
            # TensorFlow 1.0 onward (replaced by tf.summary.histogram).
            # Prefer the new API and fall back to the old one only on
            # releases that predate tf.summary.histogram.
            summary_module = getattr(tf, 'summary', None)
            histogram = getattr(summary_module, 'histogram', None)
            if histogram is None:
                histogram = tf.histogram_summary
            histogram('postprocess1_weights', w1)
            histogram('postprocess2_weights', w2)
            if self.use_biases:
                histogram('postprocess1_biases', b1)
                histogram('postprocess2_biases', b2)

        # We skip connections from the outputs of each layer, adding them
        # all up here.
        total = sum(outputs)
        transformed1 = tf.nn.relu(total)
        conv1 = tf.nn.conv1d(transformed1, w1, stride=1, padding="SAME")
        if self.use_biases:
            conv1 = tf.add(conv1, b1)
        transformed2 = tf.nn.relu(conv1)
        conv2 = tf.nn.conv1d(transformed2, w2, stride=1, padding="SAME")
        if self.use_biases:
            conv2 = tf.add(conv2, b2)

    return conv2
443
444 def _create_generator(self, input_batch, global_condition_batch):
445 '''Construct an efficient incremental generator.'''

Callers 2

predict_proba — Method · 0.95
loss — Method · 0.95

Calls 2

_create_causal_layer — Method · 0.95

Tested by

no test coverage detected