Creates a single causal dilated convolution layer. Args: input_batch: Input to the dilation layer. layer_index: Integer indicating which layer this is. dilation: Integer specifying the dilation size. global_condition_batch: Tensor containing the global data upon which the output is to be conditioned.
(self, input_batch, layer_index, dilation,
global_condition_batch, output_width)
| 243 | return causal_conv(input_batch, weights_filter, 1) |
| 244 | |
| 245 | def _create_dilation_layer(self, input_batch, layer_index, dilation, |
| 246 | global_condition_batch, output_width): |
| 247 | '''Creates a single causal dilated convolution layer. |
| 248 | |
| 249 | Args: |
| 250 | input_batch: Input to the dilation layer. |
| 251 | layer_index: Integer indicating which layer this is. |
| 252 | dilation: Integer specifying the dilation size. |
| 253 | global_conditioning_batch: Tensor containing the global data upon |
| 254 | which the output is to be conditioned upon. Shape: |
| 255 | [batch size, 1, channels]. The 1 is for the axis |
| 256 | corresponding to time so that the result is broadcast to |
| 257 | all time steps. |
| 258 | |
| 259 | The layer contains a gated filter that connects to dense output |
| 260 | and to a skip connection: |
| 261 | |
| 262 | |-> [gate] -| |-> 1x1 conv -> skip output |
| 263 | | |-> (*) -| |
| 264 | input -|-> [filter] -| |-> 1x1 conv -| |
| 265 | | |-> (+) -> dense output |
| 266 | |------------------------------------| |
| 267 | |
| 268 | Where `[gate]` and `[filter]` are causal convolutions with a |
| 269 | non-linear activation at the output. Biases and global conditioning |
| 270 | are omitted due to the limits of ASCII art. |
| 271 | |
| 272 | ''' |
| 273 | variables = self.variables['dilated_stack'][layer_index] |
| 274 | |
| 275 | weights_filter = variables['filter'] |
| 276 | weights_gate = variables['gate'] |
| 277 | |
| 278 | conv_filter = causal_conv(input_batch, weights_filter, dilation) |
| 279 | conv_gate = causal_conv(input_batch, weights_gate, dilation) |
| 280 | |
| 281 | if global_condition_batch is not None: |
| 282 | weights_gc_filter = variables['gc_filtweights'] |
| 283 | conv_filter = conv_filter + tf.nn.conv1d(global_condition_batch, |
| 284 | weights_gc_filter, |
| 285 | stride=1, |
| 286 | padding="SAME", |
| 287 | name="gc_filter") |
| 288 | weights_gc_gate = variables['gc_gateweights'] |
| 289 | conv_gate = conv_gate + tf.nn.conv1d(global_condition_batch, |
| 290 | weights_gc_gate, |
| 291 | stride=1, |
| 292 | padding="SAME", |
| 293 | name="gc_gate") |
| 294 | |
| 295 | if self.use_biases: |
| 296 | filter_bias = variables['filter_bias'] |
| 297 | gate_bias = variables['gate_bias'] |
| 298 | conv_filter = tf.add(conv_filter, filter_bias) |
| 299 | conv_gate = tf.add(conv_gate, gate_bias) |
| 300 | |
| 301 | out = tf.tanh(conv_filter) * tf.sigmoid(conv_gate) |
| 302 |
no test coverage detected