(title, p_z, p_x1z)
| 16 | ds.Distr.default_device = device |
| 17 | |
def test_fun(title, p_z, p_x1z):
    """Smoke-test a distribution over ``z``, one over ``x`` given ``z``, and their product.

    The report printed to stdout has these parts:

    * the ``names`` / ``parents`` of ``p_z``, ``p_x1z`` and ``p_z * p_x1z``;
    * sample shapes from a two-step draw (``p_z`` then ``p_x1z`` conditioned
      on the ``z`` sample);
    * sample shapes from a direct draw of the product;
    * for both sampling routes, whether ``p_z.logp + p_x1z.logp`` matches the
      product's ``logp`` according to ``tc.allclose``;
    * the shape returned by ``p_x1z.logp_cartes`` on the two-step samples.

    Ends with ``ds.Distr.clear()``, presumably resetting class-level state so
    the next test starts fresh (NOTE(review): confirm what ``clear`` resets).

    Args:
        title: Heading printed before the report.
        p_z: Distribution whose samples are indexed by ``'z'``.
        p_x1z: Distribution whose samples are indexed by ``'x'`` and which is
            drawn/evaluated conditioned on a ``z`` sample.
    """
    print(title)
    print("p_z:", p_z.names, p_z.parents)
    print("p_x1z:", p_x1z.names, p_x1z.parents)
    joint = p_z * p_x1z
    print("p_zx:", joint.names, joint.parents)

    # Two-step sampling: z with batch shape ``shape_bat``, then x given that z
    # (empty extra shape -- presumably one x per z sample; TODO confirm).
    draw_z = p_z.draw(shape_bat)
    print("sample shape z:", draw_z['z'].shape)
    draw_x = p_x1z.draw((), draw_z)
    print("sample shape x:", draw_x['x'].shape)
    factorised = p_z.logp(draw_z) + p_x1z.logp(draw_x, draw_z)
    # Samples appear to be dict-like keyed by variable name; ``|`` merges them
    # into a single mapping for the joint density.
    merged = draw_z | draw_x
    print("logp match:", tc.allclose(factorised, joint.logp(merged)))

    # Direct sampling from the product; the one joint sample supplies both
    # the x values and the z conditioning values.
    draw_zx = joint.draw(shape_bat)
    for label, key in (("sample shape z:", 'z'), ("sample shape x:", 'x')):
        print(label, draw_zx[key].shape)
    factorised = p_z.logp(draw_zx) + p_x1z.logp(draw_zx, draw_zx)
    print("logp match:", tc.allclose(factorised, joint.logp(draw_zx)))

    cartes = p_x1z.logp_cartes(draw_x, draw_z)
    print("logp_cartes shape:", cartes.shape)
    print()
    ds.Distr.clear()
| 40 | |
| 41 | # Normal |
| 42 | ndim_x = len(shape_x) |
no test coverage detected