Repository: github.com/ibab/tensorflow-wavenet

Method _create_variables

wavenet/model.py:127–234

This method creates all variables used by the network in one place, so that the same variables can be shared between multiple calls to the loss function and the generation function.
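The sharing idea can be illustrated without TensorFlow. The sketch below is hypothetical: `TinyModel`, `loss`, and `generate` are invented names, and a plain Python list stands in for a `tf.Variable`. Because the variable dictionary is built once in `__init__`, both code paths read the same objects, so an update made through one is visible to the other.

```python
class TinyModel:
    """Hypothetical illustration of the create-once, share-everywhere pattern.

    A Python list stands in for a tf.Variable; this is not the repository's code.
    """

    def __init__(self, width):
        # Build the variable dictionary exactly once, at construction time.
        self.variables = self._create_variables(width)

    def _create_variables(self, width):
        return {"causal_layer": {"filter": [0.5] * width}}

    def generate(self, inputs):
        # Reads the shared filter; no new variable is created here.
        w = self.variables["causal_layer"]["filter"]
        return sum(wi * xi for wi, xi in zip(w, inputs))

    def loss(self, inputs, target=1.0):
        # Reuses the very same filter object as generate().
        return (self.generate(inputs) - target) ** 2


model = TinyModel(width=2)
# Updating the shared filter affects both code paths.
model.variables["causal_layer"]["filter"][0] = 1.0
print(model.generate([1.0, 1.0]))  # 1.5
print(model.loss([1.0, 1.0]))      # 0.25
```

If each call instead created its own variables, the loss function would train one set of weights while generation read another, untrained set.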

Signature: _create_variables(self)


def _create_variables(self):
    '''This function creates all variables used by the network.
    This allows us to share them between multiple calls to the loss
    function and generation function.'''

    var = dict()

    with tf.variable_scope('wavenet'):
        if self.global_condition_cardinality is not None:
            # We only look up the embedding if we are conditioning on a
            # set of mutually-exclusive categories. We can also condition
            # on an already-embedded dense vector, in which case it's
            # given to us and we don't need to do the embedding lookup.
            # Still another alternative is no global condition at all, in
            # which case we also don't do a tf.nn.embedding_lookup.
            with tf.variable_scope('embeddings'):
                layer = dict()
                layer['gc_embedding'] = create_embedding_table(
                    'gc_embedding',
                    [self.global_condition_cardinality,
                     self.global_condition_channels])
                var['embeddings'] = layer

        with tf.variable_scope('causal_layer'):
            layer = dict()
            if self.scalar_input:
                initial_channels = 1
                initial_filter_width = self.initial_filter_width
            else:
                initial_channels = self.quantization_channels
                initial_filter_width = self.filter_width
            layer['filter'] = create_variable(
                'filter',
                [initial_filter_width,
                 initial_channels,
                 self.residual_channels])
            var['causal_layer'] = layer

        var['dilated_stack'] = list()
        with tf.variable_scope('dilated_stack'):
            for i, dilation in enumerate(self.dilations):
                with tf.variable_scope('layer{}'.format(i)):
                    current = dict()
                    current['filter'] = create_variable(
                        'filter',
                        [self.filter_width,
                         self.residual_channels,
                         self.dilation_channels])
                    current['gate'] = create_variable(
                        'gate',
                        [self.filter_width,
                         self.residual_channels,
                         self.dilation_channels])
                    current['dense'] = create_variable(
                        'dense',
                        [1,
                         self.dilation_channels,
                         self.residual_channels])
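To make the shapes above concrete, the following hypothetical helper (`wavenet_variable_shapes` and `count_parameters` are invented names, and the hyperparameter values in the usage are illustrative, not the repository's defaults) lists each variable the excerpt creates and counts its parameters. The excerpt stops at line 184, so any variables created after that point are not included.

```python
def wavenet_variable_shapes(dilations, filter_width, residual_channels,
                            dilation_channels, quantization_channels,
                            scalar_input=False, initial_filter_width=None,
                            gc_cardinality=None, gc_channels=None):
    """Return {scope/name: shape} for the variables shown in the excerpt."""
    shapes = {}
    # Embedding table only when conditioning on a set of categories.
    if gc_cardinality is not None:
        shapes["embeddings/gc_embedding"] = (gc_cardinality, gc_channels)
    # Causal layer: 1 input channel for scalar input, else one-hot channels.
    if scalar_input:
        in_channels, width = 1, initial_filter_width
    else:
        in_channels, width = quantization_channels, filter_width
    shapes["causal_layer/filter"] = (width, in_channels, residual_channels)
    # One filter/gate/dense triple per dilated layer.
    for i, _dilation in enumerate(dilations):
        prefix = "dilated_stack/layer{}".format(i)
        conv = (filter_width, residual_channels, dilation_channels)
        shapes[prefix + "/filter"] = conv
        shapes[prefix + "/gate"] = conv
        shapes[prefix + "/dense"] = (1, dilation_channels, residual_channels)
    return shapes


def count_parameters(shapes):
    """Sum the number of scalar entries across all shapes."""
    total = 0
    for shape in shapes.values():
        n = 1
        for dim in shape:
            n *= dim
        total += n
    return total


shapes = wavenet_variable_shapes(
    dilations=[1, 2, 4], filter_width=2, residual_channels=16,
    dilation_channels=32, quantization_channels=256)
print(len(shapes))               # 10
print(count_parameters(shapes))  # 15872
```

Note that the loop variable `dilation` plays no role in any shape: every dilated layer gets the same kernel shape, and the dilation rate only matters when the convolution is applied.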

Callers (1): __init__ (method)

Calls (3): create_embedding_table, create_variable, create_bias_variable (functions)
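The bodies of the three called helpers are not shown on this page. As a rough, hypothetical stand-in, the pure-Python sketch below assumes Xavier (Glorot) uniform initialization for filters and embedding tables and zero initialization for biases; the `*_sketch` names are invented, nested lists replace tensors, and none of this is the repository's actual implementation.

```python
import math
import random


def _fill(shape, sample):
    # Build a nested list of the given shape, calling sample() for each entry.
    if len(shape) == 1:
        return [sample() for _ in range(shape[0])]
    return [_fill(shape[1:], sample) for _ in range(shape[0])]


def xavier_limit(shape):
    # Assumed Glorot-uniform bound for shapes with at least 2 dims:
    # fan_in/fan_out are scaled by the receptive field (product of leading dims).
    receptive = 1
    for d in shape[:-2]:
        receptive *= d
    fan_in = shape[-2] * receptive
    fan_out = shape[-1] * receptive
    return math.sqrt(6.0 / (fan_in + fan_out))


def create_variable_sketch(shape, rng):
    # Filters, and (by assumption) embedding tables, drawn from U(-limit, limit).
    limit = xavier_limit(shape)
    return _fill(shape, lambda: rng.uniform(-limit, limit))


def create_bias_variable_sketch(shape):
    # Biases assumed to start at zero.
    return _fill(shape, lambda: 0.0)


rng = random.Random(0)
causal_filter = create_variable_sketch([2, 256, 16], rng)  # [width, in, out]
gc_table = create_variable_sketch([10, 8], rng)            # [cardinality, channels]
bias = create_bias_variable_sketch([16])
```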

Tested by: no test coverage detected