
Function _replace_constant_of_shape_with_range

onnx/tools/replace_constants.py:71–154

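Replaces every *ConstantOfShape* node with a small subgraph built around *Range*, so that the model no longer contains constant tensors. The function is not recursive: it only rewrites the nodes of the graph or function it receives. Recursion is handled by *replace_initializer_by_constant_of_shape*.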
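The rewrite can be illustrated without ONNX. The following is a minimal pure-Python sketch, not the onnx implementation: nodes are modelled with a hypothetical `Node` dataclass (not an ONNX type), and it mirrors the node chain and the `_find_name` naming scheme of the function below. The `value_dtype` attribute key and placing the two shared constants at the front of the node list are assumptions, because the excerpt does not show how the new nodes are spliced back into the graph.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for an ONNX NodeProto (not an ONNX type)."""
    op_type: str
    inputs: list
    outputs: list
    attrs: dict = field(default_factory=dict)


def unique_namer(nodes):
    """Return a function that hands out tensor names no node already uses."""
    used = {name for n in nodes for name in (*n.inputs, *n.outputs)}

    def find_name(prefix):
        # `prefix` if free, otherwise the first free `prefix_2`, `prefix_3`, ...
        name, i = prefix, 2
        while name in used:
            name = f"{prefix}_{i}"
            i += 1
        used.add(name)
        return name

    return find_name


def replace_constant_of_shape(nodes):
    """Swap each ConstantOfShape for ReduceProd/Range/Cast/Cast/Div/Reshape."""
    find_name = unique_namer(nodes)
    # Shared Range start (0) and step (1); putting them first is an assumption.
    zero = Node("Constant", [], [find_name("zero")], {"value_int": 0})
    one = Node("Constant", [], [find_name("one")], {"value_int": 1})
    out = [zero, one]
    for node in nodes:
        if node.op_type != "ConstantOfShape":
            out.append(node)
            continue
        shape = node.inputs[0]
        to = node.attrs.get("value_dtype", "FLOAT")  # hypothetical attribute key
        n = Node("ReduceProd", [shape], [find_name(f"{shape}_N")])
        rng = Node("Range", [zero.outputs[0], n.outputs[0], one.outputs[0]],
                   [find_name(f"{shape}_RANGE")])
        ac = Node("Cast", [rng.outputs[0]], [find_name(f"{shape}_RANGEf")], {"to": to})
        cl = Node("Cast", [n.outputs[0]], [find_name(f"{shape}_Nf")], {"to": to})
        d = Node("Div", [ac.outputs[0], cl.outputs[0]], [find_name(f"{shape}_FLAT")])
        # The final Reshape reuses the original output names.
        resh = Node("Reshape", [d.outputs[0], shape], list(node.outputs))
        out.extend([n, rng, ac, cl, d, resh])
    return out
```

Because the final `Reshape` keeps the original output name, downstream consumers of the old `ConstantOfShape` output need no rewiring.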

(
    onx: GraphProto | FunctionProto,
)

```python
from __future__ import annotations

# Names used by this excerpt.
from onnx import FunctionProto, GraphProto, TensorProto
from onnx.helper import make_node


def _replace_constant_of_shape_with_range(
    onx: GraphProto | FunctionProto,
) -> GraphProto | FunctionProto:
    """Replaces all *ConstantOfShape* by node *Range* to avoid constant tensors.

    The function is not recursive. The recursivity is done by
    *replace_initializer_by_constant_of_shape*.
    """
    if isinstance(onx, GraphProto):
        nodes = list(onx.node)
    elif isinstance(onx, FunctionProto):
        nodes = list(onx.node)
    else:
        raise TypeError(f"Not implemented for type {type(onx)}.")

    # Collect every tensor name already in use so new names never collide.
    existing_names = set()
    for node in nodes:
        existing_names |= set(node.input)
        existing_names |= set(node.output)

    def _find_name(prefix):
        # Return `prefix` if unused, otherwise the first free `prefix_2`, `prefix_3`, ...
        if prefix not in existing_names:
            existing_names.add(prefix)
            return prefix
        i = 2
        while True:
            name = f"{prefix}_{i}"
            if name not in existing_names:
                existing_names.add(name)
                return name
            i += 1
        # The function should never go through that line.
        raise RuntimeError("The function should never go through that line.")

    # Shared start (0) and step (1) inputs for every Range node.
    cst0 = make_node("Constant", [], [_find_name("zero")], value_int=0)
    cst1 = make_node("Constant", [], [_find_name("one")], value_int=1)
    update = {}
    for inode, node in enumerate(nodes):
        if node.op_type != "ConstantOfShape":
            continue
        shape = node.input[0]

        # N = number of elements described by `shape`.
        n = make_node("ReduceProd", [shape], [_find_name(f"{shape}_N")])
        # 0, 1, ..., N - 1
        a = make_node(
            "Range",
            [cst0.output[0], n.output[0], cst1.output[0]],
            [_find_name(f"{shape}_RANGE")],
        )
        # Element type of the ConstantOfShape "value" attribute, FLOAT if absent.
        if len(node.attribute) == 1:
            to = node.attribute[0].t.data_type
        else:
            to = TensorProto.FLOAT
        ac = make_node("Cast", [a.output[0]], [_find_name(f"{shape}_RANGEf")], to=to)
        cl = make_node("Cast", [n.output[0]], [_find_name(f"{shape}_Nf")], to=to)
        # i / N for every element.
        d = make_node(
            "Div", [ac.output[0], cl.output[0]], [_find_name(f"{shape}_FLAT")]
        )
        # Reshape writes to the original outputs, so consumers are unaffected.
        resh = make_node("Reshape", [d.output[0], shape], node.output)
        # ... (lines 129–154 of the source file are not shown in this excerpt)
```
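Numerically, the chain ReduceProd, Range, Cast, Div, Reshape fills the original shape with i / N for i = 0, ..., N - 1, where N is the element count. The values are therefore distinct and lie in [0, 1) instead of being the *ConstantOfShape* fill value. The following pure-Python sketch of that arithmetic assumes a floating-point target type; `range_fill` and `_reshape` are hypothetical helpers, not part of onnx.

```python
from math import prod


def range_fill(shape):
    """Mimic ReduceProd -> Range -> Cast -> Div -> Reshape for a float output."""
    n = prod(shape)                      # ReduceProd(shape)
    flat = [i / n for i in range(n)]     # Range(0, n, 1), cast to float, then Div by n
    return _reshape(flat, list(shape))   # Reshape(flat, shape)


def _reshape(flat, shape):
    """Row-major reshape of a flat list into nested lists."""
    if not shape:                        # scalar shape: a single element
        return flat[0]
    if len(shape) == 1:
        return flat[: shape[0]]
    step = prod(shape[1:])               # elements per slice along the first axis
    return [_reshape(flat[k * step:(k + 1) * step], shape[1:])
            for k in range(shape[0])]
```

With an integer `to`, the `Div` would operate on integers, which this sketch does not model.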

Calls 4

make_node (function)
make_graph (function)
make_function (function)
_find_name (function)

Tested by

no test coverage detected
