MCPcopy
hub / github.com/hyperopt/hyperopt / replace_repeat_stochastic

Function replace_repeat_stochastic

hyperopt/vectorize.py:162–204  ·  view source on GitHub ↗
(expr, return_memo=False)

Source from the content-addressed store, hash-verified

160
161
def replace_repeat_stochastic(expr, return_memo=False):
    """Replace stochastic ``idxs_map`` nodes with single sized draws.

    Visits the nodes of ``expr`` in ``dfs`` order.  Each node named
    ``"idxs_map"`` whose second positional argument wraps an object found in
    ``stoch`` is rewritten as ``Apply(dist, <args>, <kwargs>)`` with an extra
    ``size`` keyword equal to ``scope.len(idxs)``.  Every node that comes
    after it in the traversal has ``replace_input(old, new)`` called on it,
    so the graph is rewired in place.

    Parameters
    ----------
    expr : Apply
        Root of the expression graph to rewrite.
    return_memo : bool, optional
        When true, also return the dict mapping each replaced ``idxs_map``
        node to the node that replaced it.

    Returns
    -------
    expr, or (expr, memo) if ``return_memo`` is true.  ``expr`` is the
    replacement node when the root itself was rewritten.

    Raises
    ------
    NotImplementedError
        If an argument's index node is not the map's own ``idxs`` and its
        values are not an ``asarray(repeat(...))`` node, or if the rebuilt
        node already carries a ``size`` keyword.
    AssertionError
        If an argument is not a two-element ``pos_args`` node (the checks
        are ``assert`` statements, so they vanish under ``python -O``).
    """

    def collapse_arg(pair, idxs):
        # -- each argument of idxs_map is a pos_args(idxs, vals) pair
        assert pair.name == "pos_args"
        assert len(pair.pos_args) == 2
        vals = pair.pos_args[1]
        if vals.name == "asarray":
            head = vals.inputs()[0]
            if head.name == "repeat":
                # -- draws are iid, so forget about repeating the
                # distribution parameters: use the repeated value itself
                return head.inputs()[1]
        if pair.pos_args[0] is idxs:
            return vals
        # -- pair.pos_args[0] is a superset of idxs
        # TODO: slice out correct elements using
        # idxs_take, but more importantly - test this case.
        raise NotImplementedError()

    graph = dfs(expr)
    memo = {}
    for position, node in enumerate(graph):
        is_random_map = (
            node.name == "idxs_map" and node.pos_args[1]._obj in stoch
        )
        if not is_random_map:
            continue

        # -- this is an idxs_map of a random draw of distribution `dist`
        idxs = node.pos_args[0]
        dist = node.pos_args[1]._obj

        new_pos_args = [collapse_arg(a, idxs) for a in node.pos_args[2:]]
        new_named_args = [
            [key, collapse_arg(a, idxs)] for key, a in node.named_args
        ]
        vnode = Apply(dist, new_pos_args, new_named_args, None)
        n_times = scope.len(idxs)
        if any(key == "size" for key, _ in vnode.named_args):
            raise NotImplementedError("random node already has size")
        vnode.named_args.append(["size", n_times])

        # -- every later node that *uses* the old one now uses vnode
        for client in graph[position + 1 :]:
            client.replace_input(node, vnode)
        if node is expr:
            expr = vnode
        memo[node] = vnode

    return (expr, memo) if return_memo else expr
205
206
207class VectorizeHelper:

Callers 3

test_vectorize_trivial — Function · 0.90
test_vectorize_simple — Function · 0.90
test_vectorize_config0 — Function · 0.90

Calls 5

dfs — Function · 0.85
foo — Function · 0.85
Apply — Class · 0.85
len — Method · 0.80
replace_input — Method · 0.45

Tested by 3

test_vectorize_trivial — Function · 0.72
test_vectorize_simple — Function · 0.72
test_vectorize_config0 — Function · 0.72