Return the output shape of the given PyTorch node.
(torch_node)
| 46 | |
| 47 | |
# Matches the first concrete tensor type annotation in a graph value's text,
# e.g. ``Float(1, 3, 224, 224)`` or ``Long(2, 3)``, for any scalar-type name.
_OUTPUT_SHAPE_RE = re.compile(
    r"\b(?:Float|Double|Half|BFloat16|Long|Int|Short|Char|Byte|Bool"
    r"|ComplexHalf|ComplexFloat|ComplexDouble)\(\s*"
    # One or more comma-separated integer sizes, each optionally followed by
    # ``:stride`` (the ``size:stride`` layout some PyTorch versions print).
    r"(\d+(?::\d+)?(?:\s*,\s*\d+(?::\d+)?)*)"
    # The sizes end at the closing paren, or where keyword metadata such as
    # ``strides=[...]`` / ``requires_grad=0`` / ``device=cpu`` begins. Anything
    # else (e.g. a symbolic ``*`` dimension) means no concrete shape.
    r"\s*(?:\)|,\s*[A-Za-z_]\w*=)"
)


def get_shape(torch_node):
    """Return the output shape of the given PyTorch node.

    The shape is parsed from the string representation of the node's first
    output. This is a hack, because there did not seem to be an official way
    to do it. See the question on the PyTorch forum:
    https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2

    Args:
        torch_node: A traced graph node whose ``outputs()`` returns an iterator
            of output values.

    Returns:
        A tuple of ints with the sizes of the first output, or None if the node
        has no outputs or no concrete tensor shape appears in its text (e.g. a
        non-tensor output, a 0-dim tensor, or symbolic dimensions).
    """
    # TODO: find a better way to extract output shape
    # TODO: Assuming the node has one output. Update if we encounter a multi-output node.
    output = next(torch_node.outputs(), None)
    if output is None:
        return None
    # Take the FIRST annotation rather than the last one (as a greedy ``.*``
    # prefix would). This assumes the value's text lists the node's outputs
    # before anything else that could carry a type annotation. TODO confirm
    # across PyTorch versions.
    match = _OUTPUT_SHAPE_RE.search(str(output))
    if match is None:
        return None
    # Drop any ``:stride`` suffix and keep only the size of each dimension.
    return tuple(int(dim.split(":", 1)[0]) for dim in match.group(1).split(","))
| 64 | |
| 65 | |
| 66 | def import_graph(hl_graph, model, args, input_names=None, verbose=False): |