Replaces a Constant node holding a large tensor (one with more than `threshold` elements) with a sequence of nodes that produce a dummy constant of the same shape as the original tensor.
(node: NodeProto, threshold: int, value_constant_of_shape: float)
| 29 | |
| 30 | |
def _replace_constant(
    node: NodeProto, threshold: int, value_constant_of_shape: float
) -> list[NodeProto]:
    """Replaces a large Constant node by nodes producing a dummy tensor of the same shape.

    A Constant whose ``value`` tensor holds more than ``threshold`` elements is
    replaced by two nodes:

    * a small ``Constant`` holding the original dimensions as an int64 tensor,
    * a ``ConstantOfShape`` filled with ``value_constant_of_shape`` converted
      to the original tensor's element type.

    The ``ConstantOfShape`` node reuses the original node's outputs, so
    consumers of the constant need no rewiring.

    Args:
        node: The node to rewrite; its ``op_type`` must be ``"Constant"``.
        threshold: Constants with at most this many elements are returned
            unchanged.
        value_constant_of_shape: Fill value of the dummy tensor (converted via
            numpy to the original element type, so e.g. it is truncated for
            integer tensors).

    Returns:
        ``[node]`` when the constant is small enough or the node has no
        attribute; otherwise ``[shape_node, constant_of_shape_node]``.

    Raises:
        TypeError: If ``node`` is not a Constant.
        NotImplementedError: If the constant is sparse (``sparse_value``) or
            is stored in any attribute other than ``value``.
    """
    if node.op_type != "Constant":
        raise TypeError(f"Node type must be 'Constant' not {node.op_type!r}.")
    for att in node.attribute:
        if att.name == "sparse_value":
            raise NotImplementedError(
                f"This feature is not yet implemented for a sparse constant "
                f"(node name={node.name!r})."
            )
        if att.name != "value":
            raise NotImplementedError(
                f"Replacement of constant with attribute {att.name!r}"
            )

        value = att.t
        dims = list(value.dims)
        # np.prod of an empty list is 1, so scalars count as one element.
        size = np.prod(dims, dtype=np.int64)
        if size <= threshold:
            return [node]

        # Name the intermediate shape tensor after the node's output, not after
        # the embedded tensor: ``value.name`` is frequently empty, which would
        # give every replaced constant the same "__SHAPE" output and break SSA
        # when several constants of one graph are rewritten. Output names are
        # unique in a valid graph.
        base_name = node.output[0] if node.output else value.name
        new_name = f"{base_name}__SHAPE"

        init = from_array(np.array(dims, dtype=np.int64), name=new_name)
        dtype = tensor_dtype_to_np_dtype(value.data_type)
        node_shape = make_node(
            "Constant",
            [],
            [new_name],
            value=init,
        )
        new_node = make_node(
            "ConstantOfShape",
            [new_name],
            node.output,
            value=from_array(np.array([value_constant_of_shape], dtype=dtype)),
        )
        return [node_shape, new_node]
    # Only reached when the node carries no attribute at all.
    return [node]
| 69 | |
| 70 | |
| 71 | def _replace_constant_of_shape_with_range( |
no test coverage detected
searching dependent graphs…