Returns embedding for global condition. :param global_condition: Either ID of global condition for tf.nn.embedding_lookup or actual embedding. The latter is experimental. :return: Embedding or None
(self, global_condition)
| 529 | return encoded |
| 530 | |
def _embed_gc(self, global_condition):
    '''Returns the embedding for the global condition, or None.

    :param global_condition: Either the ID(s) of the global condition, used
           as indices for tf.nn.embedding_lookup into the 'gc_embedding'
           table (when self.global_condition_cardinality is not None), or an
           already-computed embedding whose last dimension must equal
           self.global_condition_channels (experimental). May be None only
           when global_condition_cardinality is None.
    :return: The embedding reshaped to
             [self.batch_size, 1, self.global_condition_channels], or None
             when no cardinality is configured and no condition is given.
    :raises ValueError: If global_condition_cardinality is set but
            global_condition is None, or if a pre-computed embedding's last
            dimension does not match global_condition_channels.
    '''
    embedding = None
    if self.global_condition_cardinality is not None:
        # Only look up the embedding if the global condition is presented
        # as an integer of mutually-exclusive categories ...
        if global_condition is None:
            # Fail early with a clear message rather than handing a None
            # index to tf.nn.embedding_lookup.
            raise ValueError(
                'global_condition must be provided when '
                'global_condition_cardinality is set.')
        embedding_table = self.variables['embeddings']['gc_embedding']
        embedding = tf.nn.embedding_lookup(embedding_table, global_condition)
    elif global_condition is not None:
        # ... else the global_condition (if any) is already provided
        # as an embedding.

        # In this case, the number of global_embedding channels must be
        # equal to the last dimension of the global_condition tensor.
        gc_shape = global_condition.get_shape()
        # Kept as `not (a == b)` rather than `a != b` so the check behaves
        # exactly as before, including when the last dimension is unknown.
        dims_match = gc_shape[-1] == self.global_condition_channels
        if not dims_match:
            raise ValueError('Shape of global_condition {} does not'
                             ' match global_condition_channels {}.'.
                             format(gc_shape,
                                    self.global_condition_channels))
        embedding = global_condition

    if embedding is not None:
        # Force an explicit [batch_size, 1, channels] shape. The singleton
        # middle axis presumably lets the embedding broadcast across time
        # steps downstream (assumption -- confirm against callers).
        embedding = tf.reshape(
            embedding,
            [self.batch_size, 1, self.global_condition_channels])

    return embedding
| 567 | |
| 568 | def predict_proba(self, waveform, global_condition=None, name='wavenet'): |
| 569 | '''Computes the probability distribution of the next sample based on |
no outgoing calls
no test coverage detected