MCPcopy
hub / github.com/jindongwang/transferlearning / logit_y1x_src

Method logit_y1x_src

code/deep/CSG/methods/semvar.py:214–231  ·  view source on GitHub ↗
(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True)

Source from the content-addressed store, hash-verified

212 return llh_vals.mean().item()
213
def logit_y1x_src(self, x: tc.Tensor, n_mc_q: int=0, repar: bool=True):
    """Return logits of `y` given `x` built from the source-domain components.

    Every input is paired with every candidate label, and the label
    log-probability `self.p_y1s.logp` is passed through the `expect` method
    of an inference distribution.

    Args:
        x: Input tensor. Its trailing dimensions are presumably
            `self.shape_x`, with any leading dimensions treated as batch
            dimensions -- TODO confirm against `ds.tcsize_div`.
        n_mc_q: Forwarded to `expect` as its third positional argument.
            When 0, the result of `expect` is used as is. Otherwise
            `expect` is called with `reducefn=tc.logsumexp` and
            `math.log(n_mc_q)` is subtracted from the result (looks like a
            log-mean-exp over `n_mc_q` samples -- verify against `expect`).
        repar: Forwarded unchanged to `expect`.

    Returns:
        If `self.dim_y == 1`: the difference between the label-1 and
        label-0 entries along the last axis, followed by `squeeze(-1)`.
        Otherwise: the logits as returned by `expect`.
    """
    is_binary = self.dim_y == 1
    # A 1-dimensional y is evaluated on two candidate labels (0 and 1).
    n_labels = 2 if is_binary else self.dim_y

    label_ids = tc.arange(n_labels, device=x.device)
    lead_shape = ds.tcsize_div(x.shape, self.shape_x)
    y_all = ds.expand_front(label_ids, lead_shape)
    x_rep = ds.expand_middle(x, (n_labels,), -len(self.shape_x))
    obs = ds.edic({'x': x_rep, 'y': y_all})

    if self.q_sv1x is not None:
        infer_distr = self.q_sv1x

        def integrand(dc):
            return self.p_y1s.logp(dc, dc)
    else:
        # No `q_sv1x` available: use `qt_sv1x` and add the log-ratio
        # `p_sv.logp - pt_sv.logp` to the label log-probability.
        infer_distr = self.qt_sv1x

        def integrand(dc):
            return self.p_sv.logp(dc, dc) - self.pt_sv.logp(dc, dc) + self.p_y1s.logp(dc, dc)

    if n_mc_q == 0:
        logits = infer_distr.expect(integrand, obs, 0, repar)
    else:
        logits = infer_distr.expect(
            integrand, obs, n_mc_q, repar, reducefn=tc.logsumexp
        ) - math.log(n_mc_q)

    if is_binary:
        return (logits[..., 1] - logits[..., 0]).squeeze(-1)
    return logits
232
233 def generate(self, shape_mc: tc.Size=tc.Size(), mode: str="src") -> tuple:
234 if mode == "src": smp_sv = self.p_sv.draw(shape_mc)

Callers

nothing calls this directly

Calls 4

expand_front · Method · 0.80
expect · Method · 0.80
log · Method · 0.80
logp · Method · 0.45

Tested by

no test coverage detected