hub / github.com/ibab/tensorflow-wavenet / _create_dilation_layer

Method _create_dilation_layer

wavenet/model.py:245–336

Creates a single causal dilated convolution layer.

(self, input_batch, layer_index, dilation,
                               global_condition_batch, output_width)

Source from the content-addressed store, hash-verified

        return causal_conv(input_batch, weights_filter, 1)

    def _create_dilation_layer(self, input_batch, layer_index, dilation,
                               global_condition_batch, output_width):
        '''Creates a single causal dilated convolution layer.

        Args:
             input_batch: Input to the dilation layer.
             layer_index: Integer indicating which layer this is.
             dilation: Integer specifying the dilation size.
             global_condition_batch: Tensor containing the global data on
                 which the output is to be conditioned. Shape:
                 [batch size, 1, channels]. The 1 is for the axis
                 corresponding to time so that the result is broadcast to
                 all time steps.

        The layer contains a gated filter that connects to dense output
        and to a skip connection:

               |-> [gate]   -|        |-> 1x1 conv -> skip output
               |             |-> (*) -|
        input -|-> [filter] -|        |-> 1x1 conv -|
               |                                    |-> (+) -> dense output
               |------------------------------------|

        Where `[gate]` and `[filter]` are causal convolutions with a
        non-linear activation at the output. Biases and global conditioning
        are omitted due to the limits of ASCII art.

        '''
        variables = self.variables['dilated_stack'][layer_index]

        weights_filter = variables['filter']
        weights_gate = variables['gate']

        # Filter and gate branches share the input and the dilation.
        conv_filter = causal_conv(input_batch, weights_filter, dilation)
        conv_gate = causal_conv(input_batch, weights_gate, dilation)

        # Add the global conditioning term to both branches. Its time axis
        # has length 1, so the sum broadcasts over all time steps.
        if global_condition_batch is not None:
            weights_gc_filter = variables['gc_filtweights']
            conv_filter = conv_filter + tf.nn.conv1d(global_condition_batch,
                                                     weights_gc_filter,
                                                     stride=1,
                                                     padding="SAME",
                                                     name="gc_filter")
            weights_gc_gate = variables['gc_gateweights']
            conv_gate = conv_gate + tf.nn.conv1d(global_condition_batch,
                                                 weights_gc_gate,
                                                 stride=1,
                                                 padding="SAME",
                                                 name="gc_gate")

        if self.use_biases:
            filter_bias = variables['filter_bias']
            gate_bias = variables['gate_bias']
            conv_filter = tf.add(conv_filter, filter_bias)
            conv_gate = tf.add(conv_gate, gate_bias)

        # Gated activation: tanh of the filter scaled by sigmoid of the gate.
        out = tf.tanh(conv_filter) * tf.sigmoid(conv_gate)

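
The gated activation on the last line shown above can be illustrated without TensorFlow. The following is a minimal single-channel sketch in plain Python, not the repository's code: the helper names `sigmoid` and `gated_activation` are hypothetical, and the scalar `gc_filter` / `gc_gate` arguments are an assumed stand-in for the conditioning term of shape [batch size, 1, channels], added identically at every time step to mimic the broadcast.

```python
import math


def sigmoid(x):
    # Logistic function, the scalar counterpart of tf.sigmoid.
    return 1.0 / (1.0 + math.exp(-x))


def gated_activation(filter_pre, gate_pre, gc_filter=0.0, gc_gate=0.0):
    """Return tanh(filter) * sigmoid(gate) for each time step.

    filter_pre, gate_pre: lists of floats, one pre-activation per time step.
    gc_filter, gc_gate: hypothetical scalars standing in for the global
        conditioning term; the same value is added at every time step.
    """
    return [
        math.tanh(f + gc_filter) * sigmoid(g + gc_gate)
        for f, g in zip(filter_pre, gate_pre)
    ]


# Toy pre-activations for three time steps (made-up values).
filter_pre = [0.0, 1.0, -1.0]
gate_pre = [0.0, 0.0, 0.0]

# With a gate pre-activation of 0 the sigmoid is 0.5, so the tanh is halved.
half_open = gated_activation(filter_pre, gate_pre)

# A large positive conditioning value on the gate pushes the sigmoid
# towards 1, so the output approaches tanh(filter) at every time step.
wide_open = gated_activation(filter_pre, gate_pre, gc_gate=20.0)
```

Because the sigmoid lies strictly between 0 and 1, the gate can only attenuate the tanh branch, so every output stays within (-1, 1).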

Callers 1

_create_network · Method · 0.95

Calls 1

causal_conv · Function · 0.85
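
The body of `causal_conv` is not shown on this page. As a rough orientation only, a causal dilated convolution over a single-channel sequence can be sketched as below. The function name `causal_dilated_conv` is hypothetical, and the left zero-padding that keeps the output as long as the input is an assumption of this sketch; the repository's function may treat the start of the sequence differently, for example by returning a shorter output.

```python
def causal_dilated_conv(x, weights, dilation):
    """Compute y[t] = sum_i weights[i] * x[t - (K - 1 - i) * dilation].

    x: list of floats, one sample per time step.
    weights: list of K filter taps; the last tap multiplies x[t] itself.
    dilation: spacing in time steps between consecutive taps.
    Samples before the start of x are treated as zeros (an assumption).
    """
    k = len(weights)
    pad = (k - 1) * dilation
    padded = [0.0] * pad + list(x)
    return [
        sum(w * padded[t + i * dilation] for i, w in enumerate(weights))
        for t in range(len(x))
    ]


# Two taps with dilation 2: each output is x[t - 2] + x[t].
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = causal_dilated_conv(x, [1.0, 1.0], dilation=2)

# Changing only the last input sample must leave earlier outputs untouched,
# which is what makes the convolution causal.
x_changed = x[:-1] + [100.0]
y_changed = causal_dilated_conv(x_changed, [1.0, 1.0], dilation=2)
```

With dilation 1 this reduces to an ordinary causal convolution, which corresponds to the `causal_conv(input_batch, weights_filter, 1)` call at the top of the listing.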

Tested by

no test coverage detected