| 117 | |
| 118 | # Utilities |
def llh(self, x: tc.Tensor, y: tc.LongTensor=None, n_mc_marg: int=64, use_q: bool=True, mode: str="src") -> float:
    """Return the batch-averaged log-likelihood estimate of `x`, or of `(x, y)`.

    Args:
        x: Values passed to the distributions under the key 'x'.
        y: Optional values passed under the key 'y'. When given, the joint
            model is multiplied by `self.p_y1s` and the likelihood is
            evaluated on both 'x' and 'y'.
        n_mc_marg: Sample count handed to `marg` / `expect`; with
            `use_q=True` its logarithm is subtracted from the result.
        use_q: If True, estimate through the inference model
            (`expect` of `log p - log q`, reduced with `tc.logsumexp`).
            If False, call `marg(...).logp(...)` on the generative model.
        mode: "src" selects `self.p_sx` with `self.q_s1x` (falling back to
            `self.qt_s1x` when the former is falsy); "tgt" selects
            `self.pt_sx` with `self.qt_s1x` (falling back to `self.q_s1x`).

    Returns:
        The mean of the per-sample values, as a Python number (`.item()`).

    Raises:
        ValueError: If `mode` is neither "src" nor "tgt".
    """
    # Choose the generative model and the inference model for this domain.
    # The fallback relies on truthiness, so a falsy (e.g. None) preferred
    # inference model is replaced by the other domain's one.
    if mode == "src":
        gen_model = self.p_sx
        inf_model = self.q_s1x or self.qt_s1x
    elif mode == "tgt":
        gen_model = self.pt_sx
        inf_model = self.qt_s1x or self.q_s1x
    else:
        raise ValueError(f"unknown `mode` '{mode}'")

    with_label = y is not None
    observed = {'x': x, 'y': y} if with_label else {'x': x}

    def joint_model():
        # Built on demand so the product is formed exactly where it is used.
        if with_label:
            return gen_model * self.p_y1s
        return gen_model

    if use_q:
        def log_weight(dc):
            # log p(...) - log q(...), evaluated on the values in `dc`.
            return joint_model().logp(dc) - inf_model.logp(dc, dc)

        # NOTE(review): presumably `expect` draws `n_mc_marg` samples from the
        # inference model given `observed` and reduces them with `reducefn`,
        # making this a log-mean-exp importance estimate -- confirm against
        # the distribution library.
        reduced = inf_model.expect(log_weight, observed, n_mc_marg, reducefn=tc.logsumexp)
        per_sample = reduced - math.log(n_mc_marg)
    else:
        # The marginalised variable set is exactly the set of observed keys.
        marginal = joint_model().marg(set(observed), n_mc_marg)
        per_sample = marginal.logp(observed)

    return per_sample.mean().item()
| 136 | |
| 137 | def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True): |
| 138 | dim_y = 2 if self.dim_y == 1 else self.dim_y |