x: [B,T,C] x_mask: [B,T] g: [B,T,C]
(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False)
| 767 | )) |
| 768 | |
def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
    """Run ``x`` through every flow in ``self.flows`` (or backwards through them).

    Args:
        x: input of shape [B,T,C].
        x_mask: optional mask of shape [B,T]. When omitted, an all-ones
            mask is used (every time step is treated as valid).
        g: optional conditioning input of shape [B,T,C].
        reverse: if True, iterate the flows in reversed order and forward
            ``reverse=True`` to each of them.
        return_hiddens: if True, also return the output of every flow.

    Returns:
        ``(x, logdet_tot)``, or ``(x, logdet_tot, hs)`` when
        ``return_hiddens`` is True. ``x`` is transposed back so that the
        time axis is dim 1 again, ``logdet_tot`` is the sum of the
        ``logdet`` values returned by the flows (the int ``0`` when there
        are no flows), and ``hs`` lists the per-flow outputs in the
        channel-first layout used internally.
    """
    # The flows operate channel-first: [B,T,C] -> [B,C,T].
    x = x.transpose(1, 2)
    if x_mask is None:
        # The signature advertises ``x_mask=None``; previously that default
        # crashed on ``None.unsqueeze``. Fall back to "nothing is masked".
        x_mask = x.new_ones(x.size(0), 1, x.size(2))
    else:
        x_mask = x_mask.unsqueeze(1)  # [B,T] -> [B,1,T]
    if g is not None:
        g = g.transpose(1, 2)

    logdet_tot = 0
    flows = reversed(self.flows) if reverse else self.flows
    hs = []

    if self.n_sqz > 1:
        # Squeeze x and g against the *unsqueezed* mask; only afterwards
        # switch over to the mask returned alongside the squeezed x.
        x, squeezed_mask = utils.squeeze(x, x_mask, self.n_sqz)
        if g is not None:
            g, _ = utils.squeeze(g, x_mask, self.n_sqz)
        x_mask = squeezed_mask

    if self.share_cond_layers and g is not None:
        # One shared conditioning layer is applied once, up front.
        g = self.cond_layer(g)

    for flow in flows:
        x, logdet = flow(x, x_mask, g=g, reverse=reverse)
        if return_hiddens:
            hs.append(x)
        logdet_tot += logdet

    if self.n_sqz > 1:
        x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)

    # Back to the caller's layout: [B,C,T] -> [B,T,C].
    x = x.transpose(1, 2)
    if return_hiddens:
        return x, logdet_tot, hs
    return x, logdet_tot
| 806 | |
| 807 | def store_inverse(self): |
| 808 | def remove_weight_norm(m): |