Computes the probability distribution of the next sample incrementally, based on a single sample and all previously passed samples.
(self, waveform, global_condition=None,
name='wavenet')
| 590 | return tf.reshape(last, [-1]) |
| 591 | |
def predict_proba_incremental(self, waveform, global_condition=None,
                              name='wavenet'):
    '''Return the next-sample probability distribution, one step at a time.

    The distribution is computed incrementally from a single new sample
    and all previously passed samples.

    Args:
        waveform: Input passed to tf.one_hot with depth
            self.quantization_channels (presumably integer sample
            indices -- confirm against callers).
        global_condition: Optional value forwarded to self._embed_gc.
        name: Name of the tf.name_scope wrapping the created ops.

    Returns:
        A 1-D float32 tensor holding the last row of the softmax output.

    Raises:
        NotImplementedError: If self.filter_width > 2 or
            self.scalar_input is truthy.
    '''
    # Reject unsupported configurations before any ops are created.
    if self.filter_width > 2:
        raise NotImplementedError(
            "Incremental generation does not support filter_width > 2.")
    if self.scalar_input:
        raise NotImplementedError(
            "Incremental generation does not support scalar input yet.")

    channels = self.quantization_channels
    with tf.name_scope(name):
        # One-hot encode the input and flatten to rows of `channels`.
        one_hot = tf.reshape(tf.one_hot(waveform, channels), [-1, channels])
        gc_embedding = self._embed_gc(global_condition)
        logits = tf.reshape(
            self._create_generator(one_hot, gc_embedding), [-1, channels])
        # The softmax is evaluated in float64 and the result cast back
        # to float32.
        logits64 = tf.cast(logits, tf.float64)
        probabilities = tf.cast(tf.nn.softmax(logits64), tf.float32)
        # Keep only the final row and return it as a flat vector.
        final_row_start = [tf.shape(probabilities)[0] - 1, 0]
        final_row = tf.slice(probabilities, final_row_start, [1, channels])
        return tf.reshape(final_row, [-1])
| 616 | |
| 617 | def loss(self, |
| 618 | input_batch, |