(
self,
graph_or_model: GraphProto | ModelProto,
vis: list[ValueInfoProto],
**kwargs: Any,
)
| 155 | return inferred_model |
| 156 | |
def _assert_inferred(
    self,
    graph_or_model: GraphProto | ModelProto,
    vis: list[ValueInfoProto],
    **kwargs: Any,
) -> None:
    """Run inference on ``graph_or_model`` and check its ``value_info`` entries.

    The expected entries are the graph's pre-existing ``value_info`` entries,
    except that any entry whose name also appears in ``vis`` is replaced by
    the ``vis`` version. The inferred model's ``graph.value_info`` must
    contain exactly the same names. Each matching pair's ``type`` is then
    handed to ``self._compare_value_infos``.

    Args:
        graph_or_model: A ``GraphProto``, or a ``ModelProto`` whose ``.graph``
            is used to read the pre-existing ``value_info`` entries.
        vis: Expected ``ValueInfoProto`` entries; these override existing
            graph entries with the same name.
        **kwargs: Forwarded unchanged to ``self._inferred``.

    Raises:
        AssertionError: If the expected and inferred ``value_info`` names
            differ. Type mismatches are reported by
            ``self._compare_value_infos`` (presumably also via assertions;
            confirm against that helper).
    """
    graph = (
        graph_or_model
        if isinstance(graph_or_model, GraphProto)
        else graph_or_model.graph
    )
    # Entries supplied in ``vis`` take precedence over same-named entries
    # already present on the graph.
    names_in_vis = {x.name for x in vis}
    expected = [
        x for x in graph.value_info if x.name not in names_in_vis
    ] + list(vis)
    inferred_model = self._inferred(graph_or_model, **kwargs)

    # Sort both sides by name so entries can be paired positionally.
    expected = sorted(expected, key=lambda x: x.name)
    inferred = sorted(inferred_model.graph.value_info, key=lambda x: x.name)

    # Compare the names, not just the counts. Otherwise an unexpected
    # inferred entry could mask a missing expected one, and unrelated
    # entries would be paired for the type comparison below.
    expected_names = [x.name for x in expected]
    inferred_names = [x.name for x in inferred]
    assert expected_names == inferred_names, (
        f"value_info names differ: expected {expected_names}, "
        f"inferred {inferred_names}"
    )
    for v, inferred_v in zip(expected, inferred, strict=True):
        self._compare_value_infos(v.type, inferred_v.type)
| 177 | |
| 178 | def _compare_value_infos( |
| 179 | self, vi_type: TypeProto, inferred_vi_type: TypeProto |
no test coverage detected