github.com/waleedka/hiddenlayer

Function get_shape

hiddenlayer/pytorch_builder.py, lines 48–63

Return the output shape of the given PyTorch node, or None if the shape cannot be parsed.

Signature: get_shape(torch_node)


import re

def get_shape(torch_node):
    """Return the output shape of the given PyTorch node."""
    # Extract the node's output shape from its string representation.
    # This is a hack because there doesn't seem to be an official way to do it.
    # See my question in the PyTorch forum:
    # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2
    # TODO: find a better way to extract the output shape.
    # TODO: Assumes the node has one output. Update if we encounter a multi-output node.
    # Only the first output is inspected, and only a printed type of the exact
    # form "Float(d0, d1, ...)" is recognized; anything else yields None.
    m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
    if m:
        shape = m.group(1)
        shape = shape.split(",")
        shape = tuple(map(int, shape))
    else:
        shape = None
    return shape

Callers (1)

import_graph · Function · 0.85

Calls (1)

match · Method · 0.45

Tested by

No test coverage detected.
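Since the index reports no test coverage, a pytest-style sketch like the following could serve as a starting point. It copies `get_shape` so the file runs standalone and without PyTorch installed; `FakeNode` is a hypothetical stand-in for `torch._C.Node` whose `outputs()` yields plain strings. The printed strings are assumed formats, not captured PyTorch output. The last test pins down the current behavior that unrecognized formats yield None, which matters if a PyTorch release prints extra fields (such as strides) after the sizes.

```python
import re


def get_shape(torch_node):
    """Copy of hiddenlayer's get_shape, reproduced so this file runs standalone."""
    m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs())))
    if m:
        shape = m.group(1)
        shape = shape.split(",")
        shape = tuple(map(int, shape))
    else:
        shape = None
    return shape


class FakeNode:
    """Hypothetical stand-in for torch._C.Node: outputs() yields printable values."""

    def __init__(self, *printed_outputs):
        self._printed = printed_outputs

    def outputs(self):
        return iter(self._printed)


def test_parses_float_shape():
    node = FakeNode("%1 : Float(1, 3, 224, 224) = onnx::Conv(%0)")
    assert get_shape(node) == (1, 3, 224, 224)


def test_only_first_output_is_read():
    node = FakeNode("%2 : Float(2)", "%3 : Float(3)")
    assert get_shape(node) == (2,)


def test_unrecognized_formats_return_none():
    # A non-Float dtype, and a Float type with extra fields after the sizes.
    assert get_shape(FakeNode("%4 : Long(4)")) is None
    assert get_shape(FakeNode("%5 : Float(2, 3, strides=[3, 1])")) is None


if __name__ == "__main__":
    test_parses_float_shape()
    test_only_first_output_is_read()
    test_unrecognized_formats_return_none()
    print("all get_shape tests passed")
```

Under pytest the three `test_` functions are collected automatically; the `__main__` guard only lets the file run as a plain script as well.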