| 194 | |
| 195 | # Utilities |
def llh(self, x: tc.Tensor, y: tc.LongTensor=None, n_mc_marg: int=64, use_q: bool=True, mode: str="src") -> float:
    """Return the batch-averaged log-likelihood estimate of `x` (and `y`, if given).

    Args:
        x: Observed inputs, passed to the distributions under the key 'x'.
        y: Optional labels. When provided, the joint model is multiplied by
            `self.p_y1s` and the estimate is for the pair ('x', 'y').
        n_mc_marg: Sample count handed to `marg` / `expect`; its log is also
            subtracted from the log-sum-exp result in the `use_q` path.
        use_q: If true, estimate via `q_cond.expect` of `log p - log q`
            reduced with `tc.logsumexp` (presumably an importance-weighted
            estimate with `q_cond` as the proposal -- confirm against the
            `expect` implementation). Otherwise marginalize the model
            directly with `marg`.
        mode: "src" selects `self.p_svx` with `self.q_sv1x` preferred;
            "tgt" selects `self.pt_svx` with `self.qt_sv1x` preferred.

    Returns:
        The mean of the per-example values, as a Python scalar.

    Raises:
        ValueError: If `mode` is neither "src" nor "tgt".
    """
    # Choose the joint model and the conditional for the requested domain.
    # The conditional of the other domain is the fallback when the preferred
    # one is falsy (e.g. unset).
    if mode == "src":
        p_joint = self.p_svx
        q_cond = self.q_sv1x or self.qt_sv1x
    elif mode == "tgt":
        p_joint = self.pt_svx
        q_cond = self.qt_sv1x or self.q_sv1x
    else:
        raise ValueError(f"unknown `mode` '{mode}'")

    has_label = y is not None
    observed = {'x': x, 'y': y} if has_label else {'x': x}

    if use_q:
        def log_ratio(dc):
            # The model is built at evaluation time, as in a plain lambda.
            p_model = p_joint * self.p_y1s if has_label else p_joint
            return p_model.logp(dc) - q_cond.logp(dc, dc)

        log_sum = q_cond.expect(log_ratio, observed, n_mc_marg, reducefn=tc.logsumexp)
        # logsumexp minus log(n) is the log of the average ratio.
        llh_vals = log_sum - math.log(n_mc_marg)
    else:
        p_model = p_joint * self.p_y1s if has_label else p_joint
        # The marginalized variable names are exactly the observed keys.
        llh_vals = p_model.marg(set(observed), n_mc_marg).logp(observed)

    return llh_vals.mean().item()
| 213 | |
| 214 | def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True): |
| 215 | dim_y = 2 if self.dim_y == 1 else self.dim_y |