(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True)
| 212 | return llh_vals.mean().item() |
| 213 | |
def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True):
    """Return per-label scores for `x`, estimated through the latent variables.

    Every candidate label `y` in `range(n_labels)` is paired with every input in
    `x`. The log-likelihood `self.p_y1s.logp` is then evaluated under a
    variational distribution over the latents, using one of two routes:

    * If `self.q_sv1x` is set, the expectation is taken under it directly.
    * Otherwise the expectation is taken under `self.qt_sv1x`. The integrand is
      re-weighted by `p_sv.logp - pt_sv.logp`, i.e. a log importance weight of
      `p_sv` against `pt_sv`.

    Args:
        x: input batch. Its trailing dims are assumed to equal `self.shape_x`
            (the batch shape is computed as `x.shape / self.shape_x`).
        n_mc_q: number of Monte Carlo draws. With 0, `expect` is called with
            0 samples and its default reduction. With a positive value, the
            samples are combined with `tc.logsumexp` minus `log(n_mc_q)`, i.e.
            a log-mean-exp estimate.
        repar: forwarded to `expect`; presumably toggles reparameterized
            sampling (confirm against `expect`).

    Returns:
        If `self.dim_y == 1`: the difference of the label-1 and label-0 scores,
        with its last dim squeezed.
        Otherwise: the per-label scores, with labels on the last dim.
    """
    # Binary problems (dim_y == 1) still enumerate both labels {0, 1}.
    n_labels = 2 if self.dim_y == 1 else self.dim_y
    label_grid = tc.arange(n_labels, device=x.device)
    batch_shape = ds.tcsize_div(x.shape, self.shape_x)
    obs_xy = ds.edic({
        'x': ds.expand_middle(x, (n_labels,), -len(self.shape_x)),
        'y': ds.expand_front(label_grid, batch_shape),
    })

    # Choose the distribution to integrate against and the matching log-integrand.
    if self.q_sv1x is None:
        q_dist = self.qt_sv1x
        def log_integrand(dc):
            # Importance-reweight by p_sv / pt_sv in log space.
            return self.p_sv.logp(dc,dc) - self.pt_sv.logp(dc,dc) + self.p_y1s.logp(dc,dc)
    else:
        q_dist = self.q_sv1x
        def log_integrand(dc):
            return self.p_y1s.logp(dc,dc)

    if n_mc_q == 0:
        logits = q_dist.expect(log_integrand, obs_xy, 0, repar)
    else:
        # logsumexp over samples minus log(N) gives a log-mean-exp estimate.
        logits = q_dist.expect(log_integrand, obs_xy, n_mc_q, repar, reducefn=tc.logsumexp) - math.log(n_mc_q)

    if self.dim_y != 1:
        return logits
    return (logits[..., 1] - logits[..., 0]).squeeze(-1)
| 232 | |
| 233 | def generate(self, shape_mc: tc.Size=tc.Size(), mode: str="src") -> tuple: |
| 234 | if mode == "src": smp_sv = self.p_sv.draw(shape_mc) |
nothing calls this directly
no test coverage detected