(orig)
| 120 | |
| 121 | |
def vectorize_stochastic(orig):
    """Turn an ``idxs_map`` over a stochastic symbol into a single sized draw.

    When ``orig`` is an ``idxs_map`` node whose mapped object
    (``orig.pos_args[1]._obj``) is a member of ``stoch``, a new ``Apply``
    node is built that calls that object directly.

    * Positional arguments are taken from ``orig.pos_args[2:]``.
    * Named arguments are taken from ``orig.named_args``.
    * Each argument is reduced by ``_param_for``.
    * A ``["size", scope.len(idxs)]`` keyword is appended, where ``idxs``
      is ``orig.pos_args[0]``.

    Every other node is returned unchanged (the very same object).

    Parameters
    ----------
    orig : node
        Graph node; only ``name``, ``pos_args`` and ``named_args`` are read.
        Presumably a pyll ``Apply`` -- TODO confirm against callers.

    Returns
    -------
    The newly built node, or ``orig`` itself when no rewrite applies.

    Raises
    ------
    NotImplementedError
        If an argument's index node is not ``idxs`` itself and its values
        do not match the repeat pattern. Also raised if the rebuilt node
        already carries a ``size`` keyword.
    AssertionError
        If an argument is not a two-element ``pos_args`` node.
    """
    # -- guard: anything other than an idxs_map of a stochastic symbol
    #    is passed through untouched
    if orig.name != "idxs_map" or orig.pos_args[1]._obj not in stoch:
        return orig

    # -- this is an idxs_map of a random draw of distribution `dist`
    idxs = orig.pos_args[0]
    dist = orig.pos_args[1]._obj

    def _param_for(pair):
        # Reduce one (idxs, vals) argument pair to the node passed to `dist`.
        assert pair.name == "pos_args"
        assert len(pair.pos_args) == 2
        vals = pair.pos_args[1]

        # XXX: write a pattern-substitution rule for this case
        # -- pattern: idxs_take whose "vals" is asarray(repeat(...))
        if vals.name == "idxs_take":
            taken = vals.arg["vals"]
            if taken.name == "asarray":
                source = taken.inputs()[0]
                if source.name == "repeat":
                    # -- draws are iid, so forget about
                    # repeating the distribution parameters
                    return source.inputs()[1]

        if pair.pos_args[0] is not idxs:
            # -- pair.pos_args[0] is (presumably) a superset of idxs
            # TODO: slice out correct elements using
            # idxs_take, but more importantly - test this case.
            raise NotImplementedError()
        return vals

    pos_params = [_param_for(pair) for pair in orig.pos_args[2:]]
    named_params = [[key, _param_for(pair)] for key, pair in orig.named_args]
    draw = Apply(dist, pos_params, named_params, o_len=None)
    n_draws = scope.len(idxs)
    if "size" in dict(draw.named_args):
        raise NotImplementedError("random node already has size")
    draw.named_args.append(["size", n_draws])
    return draw
| 160 | |
| 161 | |
| 162 | def replace_repeat_stochastic(expr, return_memo=False): |
no test coverage detected