MCPcopy
hub / github.com/ibab/tensorflow-wavenet / _embed_gc

Method _embed_gc

wavenet/model.py:531–566  ·  view source on GitHub ↗

Returns the embedding for the global condition. Parameter `global_condition`: either the ID of the global condition, passed to `tf.nn.embedding_lookup`, or an actual embedding (the latter is experimental). Returns: the embedding, or None.

(self, global_condition)

Source from the content-addressed store, hash-verified

529 return encoded
530
def _embed_gc(self, global_condition):
    '''Returns embedding for global condition.

    :param global_condition: Either the ID(s) of the global condition, which
        are passed to tf.nn.embedding_lookup when the model has a
        global_condition_cardinality, or an actual embedding (experimental)
        whose last dimension must equal global_condition_channels. May be
        None only when global_condition_cardinality is None.
    :return: The embedding reshaped to
        [batch_size, 1, global_condition_channels], or None when
        global_condition_cardinality is None and global_condition is None.
    :raises ValueError: If global_condition_cardinality is set but
        global_condition is None, or if a provided embedding's last
        dimension does not match global_condition_channels (including a
        rank-0 embedding, which has no last dimension).
    '''
    if self.global_condition_cardinality is not None:
        # The global condition is presented as integer IDs of
        # mutually-exclusive categories, so look up its embedding in the
        # gc_embedding table.
        if global_condition is None:
            # Fail with an explicit message instead of the opaque error
            # tf.nn.embedding_lookup produces for None ids.
            raise ValueError('global_condition must be provided when '
                             'global_condition_cardinality is set.')
        embedding_table = self.variables['embeddings']['gc_embedding']
        embedding = tf.nn.embedding_lookup(embedding_table,
                                           global_condition)
    elif global_condition is not None:
        # The global condition is already provided as an embedding. In this
        # case the number of global embedding channels must equal the last
        # dimension of the global_condition tensor.
        gc_shape = global_condition.get_shape()
        # A rank-0 tensor has no last dimension; treat it as a mismatch
        # rather than letting the indexing raise IndexError. The `==`
        # comparison (not `!=`) is kept deliberately so that a dimension
        # whose equality result is falsy is rejected, as before.
        dims_match = (len(gc_shape) > 0 and
                      gc_shape[-1] == self.global_condition_channels)
        if not dims_match:
            raise ValueError('Shape of global_condition {} does not'
                             ' match global_condition_channels {}.'.
                             format(gc_shape,
                                    self.global_condition_channels))
        embedding = global_condition
    else:
        # No global conditioning in use.
        return None

    return tf.reshape(
        embedding,
        [self.batch_size, 1, self.global_condition_channels])
567
568 def predict_proba(self, waveform, global_condition=None, name='wavenet'):
569 '''Computes the probability distribution of the next sample based on

Callers 3

predict_proba · Method · 0.95
loss · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected