(block)
| 18309 | this._curr_loop = null; |
| 18310 | } |
| 18311 | assignExitContinuations(block) { |
| 18312 | for (const n of block.nodes()) { |
| 18313 | switch (n.kind()) { |
| 18314 | case 'prim::If': { |
| 18315 | this.assignExitContinuations(n.blocks().at(0)); |
| 18316 | this.assignExitContinuations(n.blocks().at(1)); |
| 18317 | break; |
| 18318 | } |
| 18319 | case 'prim::Closure': { |
| 18320 | const closure_block = new torch._C.LoopContinuations(); |
| 18321 | closure_block.run(n.blocks().at(0)); |
| 18322 | break; |
| 18323 | } |
| 18324 | case 'prim::Loop': { |
| 18325 | const prev_loop = this._curr_loop; |
| 18326 | this._curr_loop = n; |
| 18327 | this.assignExitContinuations(n.blocks().at(0)); |
| 18328 | this._curr_loop = prev_loop; |
| 18329 | break; |
| 18330 | } |
| 18331 | case 'prim::ContinueStmt': { |
| 18332 | const loop_continuation = this._graph.create('prim::LoopContinuation', 0).insertAfter(n); |
| 18333 | const header_block = loop_continuation.addBlock(); |
| 18334 | const [, pre_header] = this._curr_loop.blocks(); |
| 18335 | header_block.cloneFrom(pre_header, (v) => v); |
| 18336 | this.InlineBlockBeforeNode(n, header_block); |
| 18337 | loop_continuation.addInput(header_block.outputs()[0]); |
| 18338 | loop_continuation.eraseBlock(0); |
| 18339 | this.addLoopCarriedOutputs(loop_continuation); |
| 18340 | n.destroy(); |
| 18341 | break; |
| 18342 | } |
| 18343 | case 'prim::BreakStmt': { |
| 18344 | const loop_exit = this._graph.create('prim::LoopContinuation', 0).insertAfter(n); |
| 18345 | loop_exit.addInput(this._false_val); |
| 18346 | this.addLoopCarriedOutputs(loop_exit); |
| 18347 | n.destroy(); |
| 18348 | break; |
| 18349 | } |
| 18350 | default: { |
| 18351 | break; |
| 18352 | } |
| 18353 | } |
| 18354 | } |
| 18355 | } |
| 18356 | run(...args) { |
| 18357 | if (args.length === 1 && args[0] instanceof torch.Graph) { |
| 18358 | const [graph] = args; |
no test coverage detected