This method builds a posterior inference graph by cloning the prior graph, iterating forward in topological order and replacing prior random variables (prior_idxs, prior_vals) with new posterior distributions (post_specs, post_idxs, post_vals) that make use of observations (obs_idxs, obs_vals).
(
specs,
prior_idxs,
prior_vals,
obs_idxs,
obs_vals,
obs_loss_idxs,
obs_loss_vals,
oloss_gamma,
prior_weight,
)
| 652 | |
| 653 | |
| 654 | def build_posterior( |
| 655 | specs, |
| 656 | prior_idxs, |
| 657 | prior_vals, |
| 658 | obs_idxs, |
| 659 | obs_vals, |
| 660 | obs_loss_idxs, |
| 661 | obs_loss_vals, |
| 662 | oloss_gamma, |
| 663 | prior_weight, |
| 664 | ): |
| 665 | """ |
| 666 | This method clones a posterior inference graph by iterating forward in |
| 667 | topological order, and replacing prior random-variables (prior_idxs, prior_vals) |
| 668 | with new posterior distributions (post_specs, post_idxs, post_vals) that make use |
| 669 | of observations (obs_idxs, obs_vals). |
| 670 | |
| 671 | """ |
| 672 | assert all( |
| 673 | isinstance(arg, pyll.Apply) |
| 674 | for arg in [obs_loss_idxs, obs_loss_vals, oloss_gamma] |
| 675 | ) |
| 676 | assert set(prior_idxs.keys()) == set(prior_vals.keys()) |
| 677 | |
| 678 | expr = pyll.as_apply([specs, prior_idxs, prior_vals]) |
| 679 | nodes = pyll.dfs(expr) |
| 680 | |
| 681 | # build the joint posterior distribution as the values in this memo |
| 682 | memo = {} |
| 683 | # map prior RVs to observations |
| 684 | obs_memo = {} |
| 685 | |
| 686 | for nid in prior_vals: |
| 687 | # construct the leading args for each call to adaptive_parzen_sampler |
| 688 | # which will permit the "adaptive parzen samplers" to adapt to the |
| 689 | # correct samples. |
| 690 | obs_below, obs_above = scope.ap_split_trials( |
| 691 | obs_idxs[nid], obs_vals[nid], obs_loss_idxs, obs_loss_vals, oloss_gamma |
| 692 | ) |
| 693 | obs_memo[prior_vals[nid]] = [obs_below, obs_above] |
| 694 | for node in nodes: |
| 695 | if node not in memo: |
| 696 | new_inputs = [memo[arg] for arg in node.inputs()] |
| 697 | if node in obs_memo: |
| 698 | # -- this case corresponds to an observed Random Var |
| 699 | # node.name is a distribution like "normal", "randint", etc. |
| 700 | obs_below, obs_above = obs_memo[node] |
| 701 | aa = [memo[a] for a in node.pos_args] |
| 702 | fn = adaptive_parzen_samplers[node.name] |
| 703 | b_args = [obs_below, prior_weight] + aa |
| 704 | named_args = {kw: memo[arg] for (kw, arg) in node.named_args} |
| 705 | b_post = fn(*b_args, **named_args) |
| 706 | a_args = [obs_above, prior_weight] + aa |
| 707 | a_post = fn(*a_args, **named_args) |
| 708 | |
| 709 | # fn is a function e.g ap_uniform_sampler, ap_normal_sampler, etc |
| 710 | # b_post and a_post are pyll.Apply objects that are |
| 711 | # AST (Abstract Syntax Trees). They create the distribution, |
no test coverage detected