Repository: github.com/jindongwang/transferlearning

Function elbo_z2xy

code/deep/CSG/methods/xdistr.py:14–34

Computes the evidence lower bound (ELBO) for a supervised VAE with structure x <- z -> y, where each observation is a supervised (x, y) pair. For unsupervised observations of x alone, use `elbo(p_zx, q_z1x, obs_x)`, which treats the model as a plain VAE z -> x.
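Reading `p_zx` as the joint p(x, z), `p_y1z` as p(y | z) and `q_z1x` as the encoder q(z | x), the supervised bound and the unsupervised VAE bound can be sketched as follows. The notation is ours; the source does not write these out, and the second line is the standard VAE bound that the `elbo(p_zx, q_z1x, obs_x)` call presumably targets.

```latex
\log p(x, y) \;\ge\; \mathbb{E}_{q(z \mid x)}\big[\log p(y \mid z)\big]
  + \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big]

\log p(x) \;\ge\; \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big]
```

With `wlogpi = 1`, the `n_mc == 0` branch of the source has the shape of the first bound: `q_y1x_logpval` is the first expectation and `expc_val` is the second, which uses `q_z1x.entropy` for the -E[log q(z | x)] part when that method exists.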

(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True)
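The `n_mc` and `wlogpi` parameters select and weight the returned objective. When `n_mc > 0`, the source returns `wlogpi * q_y1x_pval.log() + expc_val / q_y1x_pval`, where `q_y1x_pval` estimates π(y | x) = E_{q(z|x)}[p(y | z)] and `expc_val` estimates E_{q(z|x)}[p(y | z)(log p(x, z) - log q(z | x))]. One way to see why this is a valid bound for `wlogpi = 1` (our derivation, not given in the source) is to use the reweighted distribution q̃(z | x, y) = q(z | x) p(y | z) / π(y | x) as the variational posterior:

```latex
\begin{aligned}
\log p(x, y)
  &\ge \mathbb{E}_{\tilde q(z \mid x, y)}\!\left[\log \frac{p(x, z)\, p(y \mid z)}{\tilde q(z \mid x, y)}\right] \\
  &= \log \pi(y \mid x) + \mathbb{E}_{\tilde q(z \mid x, y)}\!\big[\log p(x, z) - \log q(z \mid x)\big] \\
  &= \log \pi(y \mid x) + \frac{1}{\pi(y \mid x)}\,
     \mathbb{E}_{q(z \mid x)}\!\Big[p(y \mid z)\big(\log p(x, z) - \log q(z \mid x)\big)\Big]
\end{aligned}
```

The two terms correspond to `q_y1x_pval.log()` and `expc_val / q_y1x_pval`. Because `wlogpi` rescales only the first term, values other than 1 give a weighted training objective that is not necessarily a lower bound.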


__email__ = "changliu@microsoft.com"

def elbo_z2xy(p_zx: Distr, p_y1z: Distr, q_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor:
    """ For supervised VAE with structure x <- z -> y.
    Observations are supervised (x,y) pairs.
    For unsupervised observations of x data, use `elbo(p_zx, q_z1x, obs_x)` as VAE z -> x. """
    if n_mc == 0:
        q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, 0, repar) #, reducefn=tc.logsumexp)
        if hasattr(q_z1x, "entropy"): # No difference for Gaussian
            expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc), obs_xy, 0, repar) + q_z1x.entropy(obs_xy)
        else:
            expc_val = q_z1x.expect(lambda dc: p_zx.logp(dc,dc) - q_z1x.logp(dc,dc), obs_xy, 0, repar)
        return wlogpi * q_y1x_logpval + expc_val
    else:
        q_y1x_pval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp(), obs_xy, n_mc, repar)
        expc_val = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc).exp() * (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)),
                obs_xy, n_mc, repar)
        return wlogpi * q_y1x_pval.log() + expc_val / q_y1x_pval
        # q_y1x_logpval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc), obs_xy, n_mc, repar,
        #         reducefn=tc.logsumexp) - math.log(n_mc)
        # expc_logval = q_z1x.expect(lambda dc: p_y1z.logp(dc,dc) + (p_zx.logp(dc,dc) - q_z1x.logp(dc,dc)).log(),
        #         obs_xy, n_mc, repar, reducefn=tc.logsumexp) - math.log(n_mc)
        # return wlogpi * q_y1x_logpval + (expc_logval - q_y1x_logpval).exp()

def elbo_z2xy_twist(pt_zx: Distr, p_y1z: Distr, p_z: Distr, pt_z: Distr, qt_z1x: Distr, obs_xy: edic, n_mc: int=0, wlogpi: float=1., repar: bool=True) -> tc.Tensor:
    vwei_p_y1z_logp = lambda dc: p_z.logp(dc,dc) - pt_z.logp(dc,dc) + p_y1z.logp(dc,dc) # z, y:
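To compare the two return expressions of `elbo_z2xy` side by side, here is a minimal, self-contained sketch in plain Python. It does not use the repository's `Distr`/`edic` classes or its `expect` method: the 1-D Gaussian toy model, the helper names (`log_normal`, `log_p_zx`, `log_p_y1z`, `exact_log_p_xy`, `elbo_z2xy_sketch`) and the plain Monte Carlo averaging are hypothetical stand-ins. The "standard" value has the shape of the `n_mc == 0` branch (how the repository's `expect` handles `n_mc=0` is not shown in this excerpt), and the "weighted" value mirrors the `n_mc > 0` branch.

```python
import math
import random


def log_normal(v: float, mean: float, var: float) -> float:
    """Log-density of the univariate Gaussian N(mean, var) at v."""
    return -0.5 * math.log(2.0 * math.pi * var) - (v - mean) ** 2 / (2.0 * var)


# Hypothetical toy model: z ~ N(0, 1), x | z ~ N(z, 1), y | z ~ N(z, 1).
def log_p_zx(z: float, x: float) -> float:
    """log p(x, z); plays the role of p_zx.logp."""
    return log_normal(z, 0.0, 1.0) + log_normal(x, z, 1.0)


def log_p_y1z(y: float, z: float) -> float:
    """log p(y | z); plays the role of p_y1z.logp."""
    return log_normal(y, z, 1.0)


def exact_log_p_xy(x: float, y: float) -> float:
    """Closed form for the toy model: p(x) = N(0, 2) and p(y | x) = N(x / 2, 1.5)."""
    return log_normal(x, 0.0, 2.0) + log_normal(y, x / 2.0, 1.5)


def elbo_z2xy_sketch(x, y, q_mean, q_var, n_mc, wlogpi=1.0, seed=0):
    """Hypothetical stand-in for elbo_z2xy (not the repository's API).

    Uses q(z | x) = N(q_mean, q_var) and returns (standard, weighted) estimates.
    """
    rng = random.Random(seed)
    zs = [rng.gauss(q_mean, math.sqrt(q_var)) for _ in range(n_mc)]
    log_py = [log_p_y1z(y, z) for z in zs]                              # log p(y | z_i)
    log_w = [log_p_zx(z, x) - log_normal(z, q_mean, q_var) for z in zs]  # log p(x, z_i) - log q(z_i | x)

    # Standard form: wlogpi * E_q[log p(y|z)] + E_q[log p(x,z) - log q(z|x)].
    standard = wlogpi * sum(log_py) / n_mc + sum(log_w) / n_mc

    # Weighted form: wlogpi * log(pi) + E_q[p(y|z) * (log p(x,z) - log q(z|x))] / pi,
    # with pi = E_q[p(y|z)]; like the listing, this exponentiates log p(y|z) directly.
    py = [math.exp(lp) for lp in log_py]
    pi_hat = sum(py) / n_mc
    expc_val = sum(p * w for p, w in zip(py, log_w)) / n_mc
    weighted = wlogpi * math.log(pi_hat) + expc_val / pi_hat
    return standard, weighted


x_obs, y_obs = 1.0, 0.5
# q(z | x) = N(x / 2, 0.5) is the exact posterior p(z | x) of the toy model.
standard, weighted = elbo_z2xy_sketch(x_obs, y_obs, q_mean=x_obs / 2.0, q_var=0.5, n_mc=20000)
print(f"log p(x, y) = {exact_log_p_xy(x_obs, y_obs):.4f}")
print(f"standard    = {standard:.4f}")
print(f"weighted    = {weighted:.4f}")
```

Because q(z | x) is set to the exact posterior of x alone, the weighted form recovers log p(x, y) up to Monte Carlo error. The standard form falls short by a Jensen gap, which is about 0.047 nats in expectation for these values. The commented-out lines in the listing point to a log-sum-exp variant (`reducefn=tc.logsumexp`), which would be the safer choice when p(y | z) is small enough to underflow.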

Callers

nothing calls this directly

Calls 4

expect (method) · 0.80
log (method) · 0.80
logp (method) · 0.45
entropy (method) · 0.45

Tested by

no test coverage detected