This recursive procedure should be called on an output-node.
(self, node, wanted_idxs)
| 256 | assert [idxs, vals] == take.pos_args[:2] |
| 257 | |
| 258 | def build_idxs_vals(self, node, wanted_idxs): |
| 259 | """ |
| 260 | This recursive procedure should be called on an output-node. |
| 261 | """ |
| 262 | checkpoint_asserts = False |
| 263 | |
| 264 | def checkpoint(): |
| 265 | if checkpoint_asserts: |
| 266 | self.assert_integrity_idxs_take() |
| 267 | if node in self.idxs_memo: |
| 268 | toposort(self.idxs_memo[node]) |
| 269 | if node in self.take_memo: |
| 270 | for take in self.take_memo[node]: |
| 271 | toposort(take) |
| 272 | |
| 273 | checkpoint() |
| 274 | |
| 275 | # wanted_idxs is fixed, whereas idxs_memo |
| 276 | # is full of unions that can grow in subsequent recursive |
| 277 | # calls to build_idxs_vals with node as argument. |
| 278 | assert wanted_idxs != self.idxs_memo.get(node) |
| 279 | |
| 280 | # -- easy exit case |
| 281 | if node.name == "hyperopt_param": |
| 282 | # -- ignore, not vectorizing |
| 283 | return self.build_idxs_vals(node.arg["obj"], wanted_idxs) |
| 284 | |
| 285 | # -- easy exit case |
| 286 | elif node.name == "hyperopt_result": |
| 287 | # -- ignore, not vectorizing |
| 288 | return self.build_idxs_vals(node.arg["obj"], wanted_idxs) |
| 289 | |
| 290 | # -- literal case: always take from universal set |
| 291 | elif node.name == "literal": |
| 292 | if node in self.idxs_memo: |
| 293 | all_idxs, all_vals = self.take_memo[node][0].pos_args[:2] |
| 294 | wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs) |
| 295 | self.take_memo[node].append(wanted_vals) |
| 296 | checkpoint() |
| 297 | else: |
| 298 | # -- initialize idxs_memo to full set |
| 299 | all_idxs = self.expr_idxs |
| 300 | n_times = scope.len(all_idxs) |
| 301 | # -- put array_union into graph for consistency, though it is |
| 302 | # not necessary |
| 303 | all_idxs = scope.array_union(all_idxs) |
| 304 | self.idxs_memo[node] = all_idxs |
| 305 | all_vals = scope.asarray(scope.repeat(n_times, node)) |
| 306 | wanted_vals = scope.idxs_take(all_idxs, all_vals, wanted_idxs) |
| 307 | assert node not in self.take_memo |
| 308 | self.take_memo[node] = [wanted_vals] |
| 309 | checkpoint() |
| 310 | return wanted_vals |
| 311 | |
| 312 | # -- switch case: complicated |
| 313 | elif node.name == "switch": |
| 314 | if node in self.idxs_memo and wanted_idxs in self.idxs_memo[node].pos_args: |
| 315 | # -- phew, easy case |