Repository: github.com/jindongwang/transferlearning

Method expect

File: code/deep/CSG/distr/base.py, lines 84–92
Signature: expect(self, fn, conds: edic=edic(), n_mc: int=10, repar: bool=True, reducefn=tc.mean) -> tc.Tensor


```python
def expect(self, fn, conds: edic=edic(), n_mc: int=10, repar: bool=True, reducefn = tc.mean) -> tc.Tensor:
    # [shape_bat] -> [shape_bat]
    # Approximates the expectation of fn(conds | vals) under this distribution.
    if n_mc == 0:
        # No sampling: evaluate fn once at the value returned by self.mean
        vals = self.mean(conds, 0, repar)
        return fn(conds|vals)
    elif n_mc > 0:
        # Monte Carlo: draw n_mc samples along a new leading dimension,
        # give the conditioning values the same leading dimension, evaluate fn,
        # then reduce over the sample dimension (dim=0), by default with tc.mean
        vals = self.draw(tc.Size((n_mc,)), conds, repar)
        return reducefn(fn( edicify(conds)[0].expand_front((n_mc,)) | vals ), dim=0)
    else: raise ValueError(f"For {self}, negative `n_mc` {n_mc} encountered")
```
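The branching in `expect` chooses between a plug-in estimate (fn evaluated once at the mean) and a Monte Carlo average of fn over `n_mc` samples. The sketch below reproduces only that dispatch, using plain Python on a hypothetical 1-D Gaussian (`ToyNormal` is an illustration, not part of CSG); `conds`, `repar`, and batch shapes are deliberately left out. With fn(x) = x² it shows why the two branches can disagree.

```python
import random
import statistics
from typing import Callable, List, Sequence


class ToyNormal:
    """Hypothetical 1-D Gaussian standing in for a CSG distribution object."""

    def __init__(self, mu: float, sigma: float, seed: int = 0) -> None:
        self.mu, self.sigma = mu, sigma
        self.rng = random.Random(seed)  # seeded so results are reproducible

    def mean(self) -> float:
        return self.mu

    def draw(self, n: int) -> List[float]:
        # Reparameterized form: x = mu + sigma * eps, eps ~ N(0, 1)
        return [self.mu + self.sigma * self.rng.gauss(0.0, 1.0) for _ in range(n)]

    def expect(
        self,
        fn: Callable[[float], float],
        n_mc: int = 10,
        reducefn: Callable[[Sequence[float]], float] = statistics.fmean,
    ) -> float:
        if n_mc == 0:
            # Plug-in estimate: evaluate fn once at the mean, no sampling
            return fn(self.mean())
        elif n_mc > 0:
            # Monte Carlo estimate: evaluate fn on each sample, then reduce
            return reducefn([fn(x) for x in self.draw(n_mc)])
        else:
            raise ValueError(f"negative n_mc {n_mc} encountered")


d = ToyNormal(mu=1.0, sigma=2.0, seed=0)
plug_in = d.expect(lambda x: x * x, n_mc=0)   # fn(mu) = 1.0; the variance is ignored
mc = d.expect(lambda x: x * x, n_mc=100_000)  # approaches mu**2 + sigma**2 = 5.0
```

For a convex fn such as x², Jensen's inequality gives fn(E[x]) ≤ E[fn(x)], so the `n_mc=0` path is cheap but biased whenever fn is nonlinear. With the mean reduction, the Monte Carlo path is unbiased and its error shrinks roughly as 1/√n_mc. Passing a different `reducefn` (for example `max`) only changes how the per-sample values are aggregated.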

Callers (14)

_entropy (method) · 0.80
_logp (method) · 0.80
elbo (function) · 0.80
elbo_z2xy (function) · 0.80
elbo_z2xy_twist (function) · 0.80
elbo_zy2x (function) · 0.80
llh (method) · 0.80
logit_y1x_src (method) · 0.80
llh (method) · 0.80
logit_y1x_src (method) · 0.80
elbo_z2xy (function) · 0.80
elbo_z2xy_twist (function) · 0.80

Calls (5)

mean (method) · 0.95
draw (method) · 0.95
edic (class) · 0.85
edicify (function) · 0.85
expand_front (method) · 0.80
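In the Monte Carlo branch, the samples returned by `draw` carry a leading axis of size `n_mc` (the code passes `tc.Size((n_mc,))` as the sample shape), while the conditioning values do not. `expand_front((n_mc,))` appears to prepend that axis so both can be merged with `|` and passed to `fn` together. The sketch below illustrates that alignment on nested lists, assuming `expand_front` behaves like repeating a value along new leading axes; `shape` and this `expand_front` are hypothetical helpers, not the CSG implementations.

```python
from typing import Any, Dict, Tuple


def shape(x: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list; a scalar has shape ()."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def expand_front(x: Any, front: Tuple[int, ...]) -> Any:
    """Hypothetical stand-in: add leading axes with the sizes in `front`.

    Like torch's Tensor.expand, the new axes reuse the same underlying
    object instead of copying it.
    """
    for n in reversed(front):
        x = [x for _ in range(n)]
    return x


n_mc, batch = 4, 3
conds: Dict[str, Any] = {"x": [0.5, -1.0, 2.0]}                     # shape [batch]
vals: Dict[str, Any] = {"z": [[0.0] * batch for _ in range(n_mc)]}  # shape [n_mc, batch], as from a draw

# Give every conditioning value the leading sample axis, then merge with the samples
merged = {**{k: expand_front(v, (n_mc,)) for k, v in conds.items()}, **vals}
```

After the merge, every entry has shape `[n_mc, batch]`, so `fn` sees one consistent sample axis that `reducefn(..., dim=0)` can then collapse.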

Tested by

no test coverage detected