Computes the probability distribution of the next sample based on all samples in the input waveform. If you want to generate audio by feeding the output of the network back as an input, see predict_proba_incremental for a faster alternative.
(self, waveform, global_condition=None, name='wavenet')
| 566 | return embedding |
| 567 | |
def predict_proba(self, waveform, global_condition=None, name='wavenet'):
    '''Return the predicted distribution over the next sample.

    The whole of `waveform` is pushed through the network in a single pass.
    Only the softmax output in the final row of the flattened output is
    kept (presumably the last time step). For autoregressive generation,
    where the network's output is fed back in as its input,
    predict_proba_incremental is the faster alternative.

    Args:
        waveform: input samples. When `self.scalar_input` is set they are
            cast to float32 and reshaped into a single column; otherwise
            they are encoded by `self._one_hot`.
        global_condition: optional global conditioning value, forwarded to
            `self._embed_gc`.
        name: name of the TensorFlow name scope wrapping the created ops.

    Returns:
        A rank-1 float32 tensor with `self.quantization_channels` entries.
    '''
    with tf.name_scope(name):
        if not self.scalar_input:
            network_input = self._one_hot(waveform)
        else:
            # NOTE(review): this produces a 2-D [-1, 1] tensor; confirm that
            # _create_network accepts this rank for scalar input.
            network_input = tf.reshape(tf.cast(waveform, tf.float32), [-1, 1])

        # Arguments are evaluated left to right, so _embed_gc still runs
        # after the input encoding and before the network is built.
        raw_output = self._create_network(
            network_input, self._embed_gc(global_condition))

        logits = tf.reshape(raw_output, [-1, self.quantization_channels])
        # The softmax is evaluated in float64 and cast back to float32 to
        # avoid a bug in TensorFlow.
        logits64 = tf.cast(logits, tf.float64)
        proba = tf.cast(tf.nn.softmax(logits64), tf.float32)

        # Keep only the last row of the [rows, quantization_channels]
        # probability matrix, then flatten it into a vector.
        final_row = tf.shape(proba)[0] - 1
        last_step = tf.slice(proba, [final_row, 0],
                             [1, self.quantization_channels])
        return tf.reshape(last_step, [-1])
| 591 | |
| 592 | def predict_proba_incremental(self, waveform, global_condition=None, |
| 593 | name='wavenet'): |