(node)
| 18569 | } |
| 18570 | } |
| 18571 | transformLoop(node) { |
| 18572 | const loop = new torch._C.LoopView(node); |
| 18573 | const body = loop.bodyBlock(); |
| 18574 | const exit_pair = this.transformExits(body); |
| 18575 | if (this.getExitStatus(exit_pair) === 'WONT' || this.getExitStatus(exit_pair) === 'THROWS') { |
| 18576 | return this.constructWontExitPair(); |
| 18577 | } |
| 18578 | const insert = new torch._C.WithInsertPoint(body); |
| 18579 | const new_if = this._graph.insertNode(this._graph.create('prim::If', 0)); |
| 18580 | new_if.addInput(exit_pair.hasExited()); |
| 18581 | new_if.addBlock().registerOutput(this._false_val); |
| 18582 | new_if.addBlock().registerOutput(loop.nextCond()); |
| 18583 | const new_condition = new_if.addOutput().setType(torch.BoolType.get()); |
| 18584 | loop.bodyBlock().eraseOutput(0); |
| 18585 | loop.bodyBlock().insertOutput(0, new_condition); |
| 18586 | node.addInput(this._false_val); |
| 18587 | body.addInput().setType(torch.BoolType.get()); |
| 18588 | body.registerOutput(exit_pair.hasExited()); |
| 18589 | const new_has_exited = node.addOutput().setType(torch.BoolType.get()); |
| 18590 | for (const exit_value of exit_pair.exitValues()) { |
| 18591 | const typ = exit_value.type(); |
| 18592 | node.addInput(this.getUnitValue(typ)); |
| 18593 | node.addOutput().setType(typ); |
| 18594 | body.addInput().setType(typ); |
| 18595 | body.registerOutput(exit_value); |
| 18596 | } |
| 18597 | const exit_vals = node.outputs().slice(node.outputs().length - exit_pair.exitValues().size()); |
| 18598 | const result = new torch._C.ExitPair(new_has_exited, exit_vals); |
| 18599 | insert.dispose(); |
| 18600 | return result; |
| 18601 | } |
| 18602 | calcIfExitStatus(then_status, else_status) { |
| 18603 | if (then_status === 'THROWS') { |
| 18604 | return else_status; |
no test coverage detected