Replaces all *ConstantOfShape* nodes with *Range*-based nodes to avoid constant tensors. The function is not recursive; recursion into subgraphs is handled by *replace_initializer_by_constant_of_shape*.
(
onx: GraphProto | FunctionProto,
)
| 69 | |
| 70 | |
| 71 | def _replace_constant_of_shape_with_range( |
| 72 | onx: GraphProto | FunctionProto, |
| 73 | ) -> GraphProto | FunctionProto: |
| 74 | """Replaces all *ConstantOfShape* by node *Range* to avoid constant tensors. |
| 75 | |
| 76 | The function is not recursive. The recursivity is done by |
| 77 | *replace_initializer_by_constant_of_shape*. |
| 78 | """ |
| 79 | if isinstance(onx, GraphProto): |
| 80 | nodes = list(onx.node) |
| 81 | elif isinstance(onx, FunctionProto): |
| 82 | nodes = list(onx.node) |
| 83 | else: |
| 84 | raise TypeError(f"Not implemented for type {type(onx)}.") |
| 85 | |
| 86 | existing_names = set() |
| 87 | for node in nodes: |
| 88 | existing_names |= set(node.input) |
| 89 | existing_names |= set(node.output) |
| 90 | |
| 91 | def _find_name(prefix): |
| 92 | if prefix not in existing_names: |
| 93 | existing_names.add(prefix) |
| 94 | return prefix |
| 95 | i = 2 |
| 96 | while True: |
| 97 | name = f"{prefix}_{i}" |
| 98 | if name not in existing_names: |
| 99 | existing_names.add(name) |
| 100 | return name |
| 101 | i += 1 |
| 102 | # The function should never go through that line. |
| 103 | raise RuntimeError("The function should never go through that line.") |
| 104 | |
| 105 | cst0 = make_node("Constant", [], [_find_name("zero")], value_int=0) |
| 106 | cst1 = make_node("Constant", [], [_find_name("one")], value_int=1) |
| 107 | update = {} |
| 108 | for inode, node in enumerate(nodes): |
| 109 | if node.op_type != "ConstantOfShape": |
| 110 | continue |
| 111 | shape = node.input[0] |
| 112 | |
| 113 | n = make_node("ReduceProd", [shape], [_find_name(f"{shape}_N")]) |
| 114 | a = make_node( |
| 115 | "Range", |
| 116 | [cst0.output[0], n.output[0], cst1.output[0]], |
| 117 | [_find_name(f"{shape}_RANGE")], |
| 118 | ) |
| 119 | if len(node.attribute) == 1: |
| 120 | to = node.attribute[0].t.data_type |
| 121 | else: |
| 122 | to = TensorProto.FLOAT |
| 123 | ac = make_node("Cast", [a.output[0]], [_find_name(f"{shape}_RANGEf")], to=to) |
| 124 | cl = make_node("Cast", [n.output[0]], [_find_name(f"{shape}_Nf")], to=to) |
| 125 | d = make_node( |
| 126 | "Div", [ac.output[0], cl.output[0]], [_find_name(f"{shape}_FLAT")] |
| 127 | ) |
| 128 | resh = make_node("Reshape", [d.output[0], shape], node.output) |
no test coverage detected
searching dependent graphs…