Replaces the fill value of every *ConstantOfShape* node.
(
onx: GraphProto | FunctionProto, value_constant_of_shape: float
)
| 155 | |
| 156 | |
def _replace_constant_of_shape_value(
    onx: GraphProto | FunctionProto, value_constant_of_shape: float
) -> GraphProto | FunctionProto:
    """Replaces the fill value of every *ConstantOfShape* node.

    Only the nodes listed directly in ``onx.node`` are rewritten; subgraphs
    stored in node attributes (e.g. bodies of *If* / *Loop*) are not visited.

    :param onx: graph or function to process, it is not modified
    :param value_constant_of_shape: new fill value; it is stored with the
        element type of the node's existing ``value`` tensor, or as float32
        when the node has no ``value`` attribute (the ONNX default)
    :return: a copy of *onx* (same proto type) where only the ``value``
        attribute of the *ConstantOfShape* nodes differs; every other field
        (node names, value_info, doc_string, function attributes, ...) is kept
    :raises TypeError: if *onx* is neither a GraphProto nor a FunctionProto
    """
    from onnx import TensorProto

    if not isinstance(onx, (GraphProto, FunctionProto)):
        raise TypeError(f"Not implemented for type {type(onx)}.")

    # Work on a full copy instead of rebuilding the proto with make_graph /
    # make_function: those helpers only receive the fields passed explicitly
    # and would drop value_info, doc_string, function attributes, node names...
    new_onx = type(onx)()
    new_onx.CopyFrom(onx)

    for node in new_onx.node:
        if node.op_type != "ConstantOfShape":
            continue
        value_att = next(
            (att for att in node.attribute if att.name == "value"), None
        )
        if value_att is None:
            # The attribute is optional; without it ONNX fills with a
            # float32 zero, so the replacement keeps that element type.
            node.attribute.append(
                make_attribute(
                    "value",
                    make_tensor(
                        "value", TensorProto.FLOAT, [1], [value_constant_of_shape]
                    ),
                )
            )
            continue
        # Keep the original tensor name and element type, only the single
        # stored value changes.
        # NOTE(review): for integer element types make_tensor stores the value
        # in an integer field; a non-integral Python float may be rejected by
        # protobuf there — confirm what callers pass for such models.
        old_tensor = value_att.t
        value_att.t.CopyFrom(
            make_tensor(
                old_tensor.name,
                old_tensor.data_type,
                [1],
                [value_constant_of_shape],
            )
        )

    return new_onx
| 208 | |
| 209 | |
| 210 | def replace_initializer_by_constant_of_shape( # noqa: PLR0911 |
no test coverage detected
searching dependent graphs…