This function creates all variables used by the network. Creating them in one place lets the loss function and the generation function share the same variables across multiple calls.
(self)
| 125 | return receptive_field |
| 126 | |
| 127 | def _create_variables(self): |
| 128 | '''This function creates all variables used by the network. |
| 129 | This allows us to share them between multiple calls to the loss |
| 130 | function and generation function.''' |
| 131 | |
| 132 | var = dict() |
| 133 | |
| 134 | with tf.variable_scope('wavenet'): |
| 135 | if self.global_condition_cardinality is not None: |
| 136 | # We only look up the embedding if we are conditioning on a |
| 137 | # set of mutually-exclusive categories. We can also condition |
| 138 | # on an already-embedded dense vector, in which case it's |
| 139 | # given to us and we don't need to do the embedding lookup. |
| 140 | # Still another alternative is no global condition at all, in |
| 141 | # which case we also don't do a tf.nn.embedding_lookup. |
| 142 | with tf.variable_scope('embeddings'): |
| 143 | layer = dict() |
| 144 | layer['gc_embedding'] = create_embedding_table( |
| 145 | 'gc_embedding', |
| 146 | [self.global_condition_cardinality, |
| 147 | self.global_condition_channels]) |
| 148 | var['embeddings'] = layer |
| 149 | |
| 150 | with tf.variable_scope('causal_layer'): |
| 151 | layer = dict() |
| 152 | if self.scalar_input: |
| 153 | initial_channels = 1 |
| 154 | initial_filter_width = self.initial_filter_width |
| 155 | else: |
| 156 | initial_channels = self.quantization_channels |
| 157 | initial_filter_width = self.filter_width |
| 158 | layer['filter'] = create_variable( |
| 159 | 'filter', |
| 160 | [initial_filter_width, |
| 161 | initial_channels, |
| 162 | self.residual_channels]) |
| 163 | var['causal_layer'] = layer |
| 164 | |
| 165 | var['dilated_stack'] = list() |
| 166 | with tf.variable_scope('dilated_stack'): |
| 167 | for i, dilation in enumerate(self.dilations): |
| 168 | with tf.variable_scope('layer{}'.format(i)): |
| 169 | current = dict() |
| 170 | current['filter'] = create_variable( |
| 171 | 'filter', |
| 172 | [self.filter_width, |
| 173 | self.residual_channels, |
| 174 | self.dilation_channels]) |
| 175 | current['gate'] = create_variable( |
| 176 | 'gate', |
| 177 | [self.filter_width, |
| 178 | self.residual_channels, |
| 179 | self.dilation_channels]) |
| 180 | current['dense'] = create_variable( |
| 181 | 'dense', |
| 182 | [1, |
| 183 | self.dilation_channels, |
| 184 | self.residual_channels]) |
no test coverage detected