Repository: github.com/waleedka/hiddenlayer

Function import_graph

hiddenlayer/pytorch_builder.py:66–98
(hl_graph, model, args, input_names=None, verbose=False)


def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    # TODO: add input names to graph

    # Run the Pytorch graph to get a trace and generate a graph from it
    trace, out = torch.jit.get_trace_graph(model, args)
    torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
    torch_graph = trace.graph()

    # Dump list of nodes (DEBUG only)
    if verbose:
        dump_pytorch_graph(torch_graph)

    # Loop through nodes and build HL graph
    for torch_node in torch_graph.nodes():
        # Op
        op = torch_node.kind()
        # Parameters
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        # Inputs/outputs
        # TODO: inputs = [i.unique() for i in node.inputs()]
        outputs = [o.unique() for o in torch_node.outputs()]
        # Get output shape
        shape = get_shape(torch_node)
        # Add HL node
        hl_node = Node(uid=pytorch_id(torch_node), name=None, op=op,
                       output_shape=shape, params=params)
        hl_graph.add_node(hl_node)
        # Add edges
        for target_torch_node in torch_graph.nodes():
            target_inputs = [i.unique() for i in target_torch_node.inputs()]
            if set(outputs) & set(target_inputs):
                hl_graph.add_edge_by_id(pytorch_id(torch_node), pytorch_id(target_torch_node), shape)
    return hl_graph
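The edge loop in import_graph compares each node's output value IDs with the input value IDs of every node in the graph, so the number of node pairs it checks grows with the square of the node count. The sketch below reproduces that matching on toy data and, for comparison, an equivalent single-pass version that first maps each value ID to the node that produces it. It assumes each node can be reduced to an ID plus lists of input and output value IDs; ToyNode, edges_nested and edges_indexed are hypothetical names, not part of hiddenlayer or PyTorch.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:
    """Hypothetical stand-in for a traced node: an ID plus value IDs."""
    uid: str
    inputs: list = field(default_factory=list)   # IDs of values this node consumes
    outputs: list = field(default_factory=list)  # IDs of values this node produces


def edges_nested(nodes):
    """Pairwise check mirroring the nested loop in import_graph (O(n^2) pairs)."""
    edges = []
    for src in nodes:
        outs = set(src.outputs)
        for dst in nodes:
            # One edge per (src, dst) pair if any output feeds any input
            if outs & set(dst.inputs):
                edges.append((src.uid, dst.uid))
    return edges


def edges_indexed(nodes):
    """Same edge set via a value-ID -> producer map (one pass over outputs, one over inputs)."""
    producer = {}
    for node in nodes:
        for value_id in node.outputs:
            producer[value_id] = node.uid  # assumes each value has a single producer
    edges, seen = [], set()
    for dst in nodes:
        for value_id in dst.inputs:
            src = producer.get(value_id)  # None for graph inputs such as weights
            if src is not None and (src, dst.uid) not in seen:
                seen.add((src, dst.uid))  # avoid duplicates when two outputs feed one node
                edges.append((src, dst.uid))
    return edges


# A small diamond: conv feeds relu and pool, which both feed add
graph = [
    ToyNode("conv", inputs=["x", "w"], outputs=["v1"]),
    ToyNode("relu", inputs=["v1"], outputs=["v2"]),
    ToyNode("pool", inputs=["v1"], outputs=["v3"]),
    ToyNode("add", inputs=["v2", "v3"], outputs=["v4"]),
]

print(sorted(edges_nested(graph)))   # → [('conv', 'pool'), ('conv', 'relu'), ('pool', 'add'), ('relu', 'add')]
print(sorted(edges_indexed(graph)))  # → same list
```

Both functions return the same set of edges, but in a different order: the nested form groups edges by source node, as import_graph does, while the indexed form groups them by destination. The results are sorted before comparison for that reason.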

Callers

nothing calls this directly

Calls (6)

dump_pytorch_graph · function · 0.85
get_shape · function · 0.85
Node · class · 0.85
pytorch_id · function · 0.85
add_node · method · 0.80
add_edge_by_id · method · 0.80

Tested by

no test coverage detected
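One way to exercise the node and edge logic without tracing a real model is to use test doubles. import_graph only calls add_node(node) and add_edge_by_id(src_id, dst_id, label) on hl_graph, and it reads each traced node through kind(), attributeNames(), item lookup, inputs() and outputs(), then calls unique() on each value. The sketch below is a minimal, hypothetical version of that approach. FakeValue, FakeNode, RecordingGraph and build_from_nodes are all illustrative names. build_from_nodes re-implements the loop body with plain tuples in place of hiddenlayer's Node. It also uses a made-up ID scheme in place of pytorch_id(), whose exact format is not shown here, and omits get_shape(), passing None as the edge label instead of the output shape. For these reasons it checks the logic, not the real module.

```python
class FakeValue:
    """Hypothetical stand-in for a traced value; unique() returns its ID."""
    def __init__(self, uid):
        self._uid = uid

    def unique(self):
        return self._uid


class FakeNode:
    """Hypothetical stand-in exposing only the node methods import_graph touches."""
    def __init__(self, kind, inputs, outputs, attrs=None):
        self._kind = kind
        self._inputs = [FakeValue(i) for i in inputs]
        self._outputs = [FakeValue(o) for o in outputs]
        self._attrs = attrs or {}

    def kind(self):
        return self._kind

    def attributeNames(self):
        return list(self._attrs)

    def __getitem__(self, name):
        return self._attrs[name]

    def inputs(self):
        return iter(self._inputs)

    def outputs(self):
        return iter(self._outputs)


class RecordingGraph:
    """Records the two graph calls that import_graph makes on hl_graph."""
    def __init__(self):
        self.nodes = []
        self.edges = []

    def add_node(self, node):
        self.nodes.append(node)

    def add_edge_by_id(self, src_id, dst_id, label=None):
        self.edges.append((src_id, dst_id, label))


def build_from_nodes(hl_graph, torch_nodes):
    """Loop body of import_graph with (uid, op, params) tuples instead of Node."""
    def node_id(n):
        # Hypothetical ID built from output value IDs; the real code uses pytorch_id()
        return "id_" + "_".join(str(o.unique()) for o in n.outputs())

    for torch_node in torch_nodes:
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        outputs = [o.unique() for o in torch_node.outputs()]
        hl_graph.add_node((node_id(torch_node), torch_node.kind(), params))
        for target in torch_nodes:
            target_inputs = [i.unique() for i in target.inputs()]
            if set(outputs) & set(target_inputs):
                hl_graph.add_edge_by_id(node_id(torch_node), node_id(target), None)
    return hl_graph


# Illustrative two-node graph: values 0 and 1 are graph inputs
nodes = [
    FakeNode("onnx::Conv", inputs=[0, 1], outputs=[2], attrs={"group": 1}),
    FakeNode("onnx::Relu", inputs=[2], outputs=[3]),
]
g = build_from_nodes(RecordingGraph(), nodes)
print(g.nodes)  # → [('id_2', 'onnx::Conv', {'group': 1}), ('id_3', 'onnx::Relu', {})]
print(g.edges)  # → [('id_2', 'id_3', None)]
```

The same RecordingGraph could be passed to the real import_graph in a test that traces a small model, if a PyTorch version providing torch.jit.get_trace_graph is available.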