(block)
| 7690 | return changed; |
| 7691 | } |
| 7692 | optimizeBlock(block) { |
| 7693 | let changed = false; |
| 7694 | for (const node of block.nodes()) { |
| 7695 | for (const sub_block of node.blocks()) { |
| 7696 | changed = changed || this.optimizeBlock(sub_block); |
| 7697 | } |
| 7698 | if (node.kind() !== 'prim::Constant') { |
| 7699 | const guard = new torch._C.WithInsertPoint(node); |
| 7700 | for (const output of node.outputs()) { |
| 7701 | if (output.type() instanceof torch.NoneType) { |
| 7702 | output.replaceAllUsesWith(this._graph.insertConstant(new torch._C.IValue())); |
| 7703 | changed = true; |
| 7704 | } |
| 7705 | } |
| 7706 | guard.dispose(); |
| 7707 | } |
| 7708 | if (node.kind() === 'prim::If') { |
| 7709 | // throw new python.Error('Not implemented.'); |
| 7710 | /* |
| 7711 | const n = new torch._C.IfView(node); |
| 7712 | // this handles redundant short circuits like "x and True" or "x or |
| 7713 | // False" |
| 7714 | for (const auto i : c10::irange(n.outputs().length)) { |
| 7715 | if (n.outputs().at(i).type() != torch.BoolType.get()) { |
| 7716 | continue; |
| 7717 | } |
| 7718 | const true_val = constant_as<bool>(n.thenOutputs().at(i)).value_or(false); |
| 7719 | const false_val = constant_as<bool>(n.elseOutputs().at(i)).value_or(true); |
| 7720 | if (true_val && !false_val) { |
| 7721 | n.outputs().at(i).replaceAllUsesWith(n.cond()); |
| 7722 | changed = true; |
| 7723 | } |
| 7724 | } |
| 7725 | for (let i = 0; i < n.outputs().length; ++i) { |
| 7726 | const inputs_non_optional = !n.thenOutputs().at(i).type().cast<OptionalType>() && !n.elseOutputs().at(i).type().cast<OptionalType>(); |
| 7727 | const output_optional = n.outputs()[i].type(); |
| 7728 | if (inputs_non_optional && output_optional instanceof torch.OptionalType) { |
| 7729 | const unif = torch._c.unifyTypes(n.thenOutputs().at(i).type(), n.elseOutputs().at(i).type()) |
| 7730 | if (unif) { |
| 7731 | n.outputs()[i].setType(unif); |
| 7732 | changed = true; |
| 7733 | } |
| 7734 | } |
| 7735 | } |
| 7736 | */ |
| 7737 | } else if (node.kind() === 'aten::__is__' || node.kind() === 'aten::__isnot__') { |
| 7738 | torch._C.AT_ASSERT(node.inputs().length === 2); |
| 7739 | for (const check_none_index of [0, 1]) { |
| 7740 | const input_must_be_none = node.inputs()[check_none_index].mustBeNone(); |
| 7741 | const other_must_not_be_none = node.inputs().at(1 - check_none_index).mustNotBeNone(); |
| 7742 | if (input_must_be_none && other_must_not_be_none) { |
| 7743 | const guard = new torch._C.WithInsertPoint(node); |
| 7744 | const output = node.owningGraph().insertConstant(node.kind() === 'aten::__isnot__'); |
| 7745 | node.output().replaceAllUsesWith(output); |
| 7746 | changed = true; |
| 7747 | guard.dispose(); |
| 7748 | } |
| 7749 | } |
no test coverage detected