| 160 | |
| 161 | |
def replace_repeat_stochastic(expr, return_memo=False):
    """Swap each `idxs_map` over a stochastic op for a single sized draw.

    Walks the nodes returned by `dfs(expr)`.  Every node named
    "idxs_map" whose second positional argument wraps (via `_obj`) a
    member of `stoch` is replaced by a direct `Apply` of that
    distribution, given an extra "size" named argument equal to
    `scope.len(idxs)`.  Nodes that come later in the `dfs` ordering are
    re-pointed from the old node to the new one through `replace_input`.

    Args:
        expr: root of the expression graph to rewrite.
        return_memo: when true, also return the mapping from every
            replaced node to its replacement.

    Returns:
        The root expression (the replacement node, if the root itself
        was replaced), or a `(root, memo)` tuple when `return_memo` is
        true.

    Raises:
        NotImplementedError: if an argument is neither an
            `asarray(repeat(...))` nor indexed by the very same `idxs`
            object as the map, or if the freshly built node already
            carries a "size" named argument.
        AssertionError: if an argument is not a two-element "pos_args"
            node.
    """

    def unrepeated(pair, map_idxs):
        # -- each argument is an (idxs, vals) pair
        assert pair.name == "pos_args"
        assert len(pair.pos_args) == 2
        vals = pair.pos_args[1]
        if vals.name == "asarray" and vals.inputs()[0].name == "repeat":
            # -- draws are iid, so the repetition of the distribution
            #    parameters can be dropped: hand back what was repeated
            return vals.inputs()[0].inputs()[1]
        if pair.pos_args[0] is not map_idxs:
            # -- pair.pos_args[0] is a superset of map_idxs
            #    TODO: slice out correct elements using
            #    idxs_take, but more importantly - test this case.
            raise NotImplementedError()
        return vals

    ordered = dfs(expr)
    memo = {}
    for position, node in enumerate(ordered):
        if node.name != "idxs_map" or node.pos_args[1]._obj not in stoch:
            continue

        # -- `node` is an idxs_map of a random draw of distribution `dist`
        idxs = node.pos_args[0]
        dist = node.pos_args[1]._obj

        positional = [unrepeated(pair, idxs) for pair in node.pos_args[2:]]
        named = [[key, unrepeated(pair, idxs)] for key, pair in node.named_args]
        vnode = Apply(dist, positional, named, None)
        n_times = scope.len(idxs)
        if "size" in dict(vnode.named_args):
            raise NotImplementedError("random node already has size")
        vnode.named_args.append(["size", n_times])

        # -- re-point every later node that *uses* the old node
        for client in ordered[position + 1 :]:
            client.replace_input(node, vnode)
        if expr is node:
            expr = vnode
        memo[node] = vnode

    return (expr, memo) if return_memo else expr
| 205 | |
| 206 | |
| 207 | class VectorizeHelper: |