MCPcopy
hub / github.com/onnx/onnx / _assert_inferred

Method _assert_inferred

onnx/test/shape_inference_test.py:157–176  ·  view source on GitHub ↗
(
        self,
        graph_or_model: GraphProto | ModelProto,
        vis: list[ValueInfoProto],
        **kwargs: Any,
    )

Source from the content-addressed store, hash-verified

155 return inferred_model
156
def _assert_inferred(
    self,
    graph_or_model: GraphProto | ModelProto,
    vis: list[ValueInfoProto],
    **kwargs: Any,
) -> None:
    """Run shape inference and check the resulting ``value_info`` entries.

    The expected entries are ``vis`` plus every entry already present in the
    input graph's ``value_info`` whose name does not appear in ``vis``.
    After running ``self._inferred``, the inferred model's
    ``graph.value_info`` must contain exactly the same names. Each pair of
    same-named entries is then checked with ``self._compare_value_infos``.

    Args:
        graph_or_model: The graph to run inference on, or a model whose
            ``graph`` attribute is used.
        vis: Expected value infos. An entry here takes precedence over a
            same-named entry already in the graph's ``value_info``.
        **kwargs: Forwarded unchanged to ``self._inferred``.

    Raises:
        AssertionError: If the expected and inferred ``value_info`` names
            differ. Type mismatches are reported by whatever
            ``_compare_value_infos`` raises.
    """
    graph = (
        graph_or_model
        if isinstance(graph_or_model, GraphProto)
        else graph_or_model.graph
    )
    # Entries supplied in ``vis`` override same-named pre-existing entries.
    names_in_vis = {x.name for x in vis}
    expected = [
        x for x in graph.value_info if x.name not in names_in_vis
    ] + vis
    inferred_model = self._inferred(graph_or_model, **kwargs)
    expected = sorted(expected, key=lambda x: x.name)
    inferred = sorted(inferred_model.graph.value_info, key=lambda x: x.name)
    # Pairing below is positional, so the two sorted lists must agree on
    # names. Otherwise unrelated entries would be compared and a missing or
    # extra inferred name would go unreported.
    expected_names = [x.name for x in expected]
    inferred_names = [x.name for x in inferred]
    assert expected_names == inferred_names, (
        f"value_info names differ: expected {expected_names}, "
        f"inferred {inferred_names}"
    )
    for v, inferred_v in zip(expected, inferred, strict=True):
        self._compare_value_infos(v.type, inferred_v.type)
177
178 def _compare_value_infos(
179 self, vi_type: TypeProto, inferred_vi_type: TypeProto

Callers 15

_identity_prop (method) · 0.80
test_transpose (method) · 0.80
test_transpose_scalar (method) · 0.80
test_cast (method) · 0.80
test_cast_like (method) · 0.80
test_bitcast_scalar (method) · 0.80
test_bitcast_1d (method) · 0.80

Calls 2

_inferred (method) · 0.95
_compare_value_infos (method) · 0.95

Tested by

no test coverage detected