Convert a pyll expression representing a single trial into a pyll expression representing multiple trials. The resulting multi-trial expression is not meant to be evaluated directly; it is meant to serve as the input to a suggest algorithm. idxs_memo maps each node in the expression graph to all elements we might need for it; take_memo maps each node to all expressions retrieving its computed elements.
| 205 | |
| 206 | |
| 207 | class VectorizeHelper: |
| 208 | """ |
| 209 | Convert a pyll expression representing a single trial into a pyll |
| 210 | expression representing multiple trials. |
| 211 | |
| 212 | The resulting multi-trial expression is not meant to be evaluated |
| 213 | directly. It is meant to serve as the input to a suggest algo. |
| 214 | |
| 215 | idxs_memo - node in expr graph -> all elements we might need for it |
| 216 | take_memo - node in expr graph -> all exprs retrieving computed elements |
| 217 | |
| 218 | """ |
| 219 | |
def __init__(self, expr, expr_idxs, build=True):
    """Build the multi-trial (vectorized) form of the pyll graph ``expr``.

    Parameters
    ----------
    expr
        Output node of the single-trial pyll expression graph.
    expr_idxs
        Passed to ``build_idxs_vals`` as ``wanted_idxs`` for ``expr``.
        Presumably a pyll node of trial indices — TODO confirm against callers.
    build
        Accepted for backward compatibility but currently ignored: the
        vectorized graph is always built.

    Side effects
    ------------
    Sets ``self.expr``, ``self.expr_idxs``, ``self.dfs_nodes`` (the result of
    ``dfs(expr)``), ``self.params`` (label -> ``obj`` argument of every
    ``hyperopt_param`` node), ``self.idxs_memo``, ``self.take_memo`` and
    ``self.v_expr`` (the value returned by ``build_idxs_vals``). It then runs
    ``assert_integrity_idxs_take``.
    """
    self.expr = expr
    self.expr_idxs = expr_idxs
    self.dfs_nodes = dfs(expr)

    # Index every hyperopt_param node by its label. A later node with a
    # duplicate label overwrites an earlier one, as the original loop did.
    self.params = {
        node.arg["label"].obj: node.arg["obj"]
        for node in self.dfs_nodes
        if node.name == "hyperopt_param"
    }

    # -- recursive construction
    # This makes one term in each idxs, vals memo for every
    # directed path through the switches in the graph.
    self.idxs_memo = {}  # node -> union, all idxs computed
    self.take_memo = {}  # node -> list of idxs_take retrieving node vals
    self.v_expr = self.build_idxs_vals(expr, expr_idxs)

    # TODO: graph-optimization pass to remove cruft:
    # - unions of 1
    # - unions of full sets with their subsets
    # - idxs_take that can be merged

    self.assert_integrity_idxs_take()
| 243 | |
def assert_integrity_idxs_take(self):
    """Sanity-check the memo tables produced during vectorization.

    Every check uses ``assert``, so all of them are skipped under
    ``python -O``. The checks are, in order:

    * ``dfs(self.expr)`` still equals ``self.dfs_nodes`` (the node list
      recorded by ``__init__`` before building);
    * ``idxs_memo`` and ``take_memo`` have exactly the same keys;
    * every ``idxs_memo`` value is an ``array_union`` node;
    * every entry in a node's ``take_memo`` list is an ``idxs_take`` node.
      Its first two positional args must be that node's union and the
      second positional arg of the list's first take.
    """
    idxs_memo, take_memo = self.idxs_memo, self.take_memo
    nodes_now = dfs(self.expr)
    assert nodes_now == self.dfs_nodes
    assert set(idxs_memo) == set(take_memo)

    for node, union_node in idxs_memo.items():
        assert union_node.name == "array_union"
        takes = take_memo[node]
        # All takes for a node must share the vals of the first one.
        first_vals = takes[0].pos_args[1]
        for take in takes:
            assert take.name == "idxs_take"
            assert [union_node, first_vals] == take.pos_args[:2]
| 257 | |
| 258 | def build_idxs_vals(self, node, wanted_idxs): |
| 259 | """ |
| 260 | This recursive procedure should be called on an output-node. |
| 261 | """ |
| 262 | checkpoint_asserts = False |
| 263 | |
| 264 | def checkpoint(): |
no outgoing calls