For a supervised VAE with structure x <- z -> y, where observations are supervised (x, y) pairs. For unsupervised observations of x alone, use `elbo(p_zx, q_z1x, obs_x)`, treating the model as the VAE z -> x.
(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True)
| 12 | __email__ = "changliu@microsoft.com" |
| 13 | |
def elbo_z2xy(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor:
    """ Supervised objective for a VAE with structure x <- z -> y.

    Observations are supervised (x, y) pairs. For unsupervised observations of
    x alone, use `elbo(p_zx, q_z1x, obs_x)`, treating the model as the VAE z -> x.

    Args:
        p_zx: model over (z, x); its `logp(dc, dc)` is evaluated at draws `dc`
            produced by `q_z1x.expect`.
        p_y1z: model of y given z; its `logp(dc, dc)` is evaluated at the same draws.
        q_z1x: variational distribution over z given x. Provides `expect` and
            `logp`, and optionally `entropy`.
        obs_xy: observed (x, y) values forwarded to `q_z1x.expect` / `q_z1x.entropy`.
        n_mc: forwarded to `q_z1x.expect` as its sample count. 0 selects the
            first objective below; any other value selects the second.
        wlogpi: weight applied to the y-likelihood term.
        repar: forwarded unchanged to `q_z1x.expect` (presumably toggles
            reparameterized sampling -- confirm against `Distr.expect`).

    Returns:
        With E_q denoting whatever `q_z1x.expect` computes for the given `n_mc`:
          n_mc == 0:  wlogpi * E_q[log p(y|z)] + E_q[log p(z,x) - log q(z|x)]
          otherwise:  wlogpi * log E_q[p(y|z)]
                      + E_q[p(y|z) * (log p(z,x) - log q(z|x))] / E_q[p(y|z)]
    """
    def loglik_y(dc):
        return p_y1z.logp(dc, dc)

    def log_ratio_zx(dc):
        # log p(z, x) - log q(z | x); p_zx is queried before q_z1x.
        return p_zx.logp(dc, dc) - q_z1x.logp(dc, dc)

    if n_mc != 0:
        # Monte Carlo branch: weight the (z, x) log-ratio by p(y|z), normalized
        # by the estimate of E_q[p(y|z)].
        # NOTE(review): exp() of a very negative log-likelihood can underflow to 0,
        # giving log(0) = -inf and a 0/0 ratio; a logsumexp form was sketched here
        # previously but left disabled.
        # NOTE(review): the two estimates come from separate `expect` calls; whether
        # they share draws depends on `Distr.expect` -- confirm.
        mean_lik_y = q_z1x.expect(lambda dc: loglik_y(dc).exp(), obs_xy, n_mc, repar)
        wtd_log_ratio = q_z1x.expect(
            lambda dc: loglik_y(dc).exp() * log_ratio_zx(dc),
            obs_xy, n_mc, repar,
        )
        return wlogpi * mean_lik_y.log() + wtd_log_ratio / mean_lik_y

    mean_loglik_y = q_z1x.expect(loglik_y, obs_xy, 0, repar)
    if hasattr(q_z1x, "entropy"):
        # Use q's own entropy in place of E_q[-log q(z|x)]; per the original
        # author's note this makes no difference for a Gaussian q.
        elbo_zx = q_z1x.expect(lambda dc: p_zx.logp(dc, dc), obs_xy, 0, repar) + q_z1x.entropy(obs_xy)
    else:
        elbo_zx = q_z1x.expect(log_ratio_zx, obs_xy, 0, repar)
    return wlogpi * mean_loglik_y + elbo_zx
| 35 | |
| 36 | def elbo_z2xy_twist(pt_zx: Distr, p_y1z: Distr, p_z: Distr, pt_z: Distr, qt_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor: |
| 37 | vwei_p_y1z_logp = lambda dc: p_z.logp(dc,dc) - pt_z.logp(dc,dc) + p_y1z.logp(dc,dc) # z, y: |