(hl_graph, model, args, input_names=None, verbose=False)
| 64 | |
| 65 | |
def _index_consumers(torch_nodes):
    """Map each value's unique id to the positions of the nodes that read it.

    Positions index into ``torch_nodes`` and are appended in ascending order.
    A node that reads the same value more than once appears more than once.
    """
    consumers = {}
    for position, node in enumerate(torch_nodes):
        for value in node.inputs():
            consumers.setdefault(value.unique(), []).append(position)
    return consumers


def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Trace a PyTorch model and add its nodes and data-flow edges to ``hl_graph``.

    Args:
        hl_graph: Graph object to populate. It receives one ``Node`` per traced
            PyTorch node (via ``add_node``) and one edge per producer/consumer
            pair (via ``add_edge_by_id``).
        model: The PyTorch model to trace.
        args: Example input(s) passed to the model while tracing.
        input_names: Currently unused (see TODO below).
        verbose: If true, dump the traced graph via ``dump_pytorch_graph``.

    Returns:
        The same ``hl_graph`` object, mutated in place.
    """
    # TODO: add input names to graph

    # Run the PyTorch graph to get a trace and generate a graph from it.
    # NOTE(review): torch.jit.get_trace_graph and torch.onnx._optimize_trace are
    # private / version-specific PyTorch APIs (later releases renamed or changed
    # them) -- confirm against the torch version this project pins.
    trace, _ = torch.jit.get_trace_graph(model, args)
    torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    torch_graph = trace.graph()

    # Dump list of nodes (DEBUG only)
    if verbose:
        dump_pytorch_graph(torch_graph)

    # Materialize the node list once and index value consumers up front.
    # Scanning every node for every node to find edges is O(n^2).
    torch_nodes = list(torch_graph.nodes())
    consumers = _index_consumers(torch_nodes)

    # Loop through nodes and build HL graph
    for torch_node in torch_nodes:
        # Op
        op = torch_node.kind()
        # Parameters
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        # Inputs/outputs
        # TODO: inputs = [i.unique() for i in node.inputs()]
        outputs = [o.unique() for o in torch_node.outputs()]
        # Get output shape
        shape = get_shape(torch_node)
        # Add HL node
        source_id = pytorch_id(torch_node)
        hl_node = Node(uid=source_id, name=None, op=op,
                       output_shape=shape, params=params)
        hl_graph.add_node(hl_node)
        # Add one edge to every node that reads any of this node's outputs.
        # The set deduplicates targets that read several outputs, or the same
        # output twice. Sorting keeps the targets in graph order.
        target_positions = sorted({
            position
            for uid in outputs
            for position in consumers.get(uid, ())
        })
        for position in target_positions:
            hl_graph.add_edge_by_id(
                source_id, pytorch_id(torch_nodes[position]), shape)
    return hl_graph
nothing calls this directly
no test coverage detected