Construct the WaveNet network.
(self, input_batch, global_condition_batch)
| 393 | return skip_contribution, input_batch + transformed |
| 394 | |
def _create_network(self, input_batch, global_condition_batch):
    '''Construct the WaveNet network.

    The input is first passed through the causal layer, then through one
    dilation layer per entry of `self.dilations`. The first value returned
    by each dilation layer is collected, and the collected values are
    summed and post-processed with ReLU -> conv -> ReLU -> conv.

    Args:
        input_batch: Input tensor. Its size along axis 1 is used, together
            with `self.receptive_field`, to compute the output width.
        global_condition_batch: Passed unchanged to every dilation layer.

    Returns:
        The result of the second post-processing convolution, with the
        second bias added when `self.use_biases` is set.
    '''
    # Pre-process the input with a regular convolution.
    layer_input = self._create_causal_layer(input_batch)

    output_width = tf.shape(input_batch)[1] - self.receptive_field + 1

    # Run every configured dilation layer, keeping the first value each
    # one returns and feeding the second value into the next layer.
    skip_outputs = []
    with tf.name_scope('dilated_stack'):
        for layer_index, dilation in enumerate(self.dilations):
            with tf.name_scope('layer{}'.format(layer_index)):
                skip, layer_input = self._create_dilation_layer(
                    layer_input, layer_index, dilation,
                    global_condition_batch, output_width)
                skip_outputs.append(skip)

    with tf.name_scope('postprocessing'):
        # Perform (+) -> ReLU -> 1x1 conv -> ReLU -> 1x1 conv to
        # postprocess the output.
        post_vars = self.variables['postprocessing']
        weights1 = post_vars['postprocess1']
        weights2 = post_vars['postprocess2']
        bias1 = bias2 = None
        if self.use_biases:
            bias1 = post_vars['postprocess1_bias']
            bias2 = post_vars['postprocess2_bias']

        if self.histograms:
            # Weights are summarized before biases so the summary ops are
            # created in a fixed order.
            summaries = [('postprocess1_weights', weights1),
                         ('postprocess2_weights', weights2)]
            if self.use_biases:
                summaries.append(('postprocess1_biases', bias1))
                summaries.append(('postprocess2_biases', bias2))
            for tag, tensor in summaries:
                tf.histogram_summary(tag, tensor)

        # The per-layer outputs collected above are added up here before
        # the first ReLU.
        hidden = tf.nn.relu(sum(skip_outputs))
        hidden = tf.nn.conv1d(hidden, weights1, stride=1, padding="SAME")
        if self.use_biases:
            hidden = tf.add(hidden, bias1)

        result = tf.nn.conv1d(
            tf.nn.relu(hidden), weights2, stride=1, padding="SAME")
        if self.use_biases:
            result = tf.add(result, bias2)

    return result
| 443 | |
| 444 | def _create_generator(self, input_batch, global_condition_batch): |
| 445 | '''Construct an efficient incremental generator.''' |
no test coverage detected