github.com/apache/tvm

Function check_region_bound

tests/python/arith/test_arith_intset.py:140–208

Helper to check region bound estimation.

(expect_region, var_dom, mode, predicate=None)
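The key handling in the function's first loop can be modeled with plain integers. The sketch below is an illustration under stated assumptions, not TVM code: `normalize_expect_region` is a hypothetical name, ints stand in for `PrimExpr`, and a `(min, extent)` tuple stands in for `tvm.ir.Range`. It shows that a single-point key `k` is shorthand for the half-open range `[k, k + 1)`, while the expectations are collected in the same order as the regions.

```python
def normalize_expect_region(expect_region):
    """Hypothetical model of check_region_bound's key handling.

    Returns parallel lists: (min, extent) pairs standing in for tvm.ir.Range,
    and the expected values. Plain ints stand in for PrimExpr.
    """
    region = []
    expect = []
    for key, value in expect_region.items():
        if not isinstance(key, (tuple, list)):
            # A single point k is shorthand for the half-open range [k, k + 1).
            key = (key, key + 1)
        begin, end = key
        # (min, extent), analogous to Range.from_min_extent(begin, end - begin)
        region.append((begin, end - begin))
        expect.append(value)
    return region, expect


region, expect = normalize_expect_region({5: (5, 6), (0, 16): (0, 16)})
print(region)  # [(5, 1), (0, 16)]
```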


```python
import tvm
from tvm.arith import Analyzer  # assumed imports; the test file's own import block is not shown


def check_region_bound(expect_region, var_dom, mode, predicate=None):
    """Helper to check region bound estimation.

    Parameters
    ----------
    expect_region: dict
        The keys are of form (begin, end) or PrimExpr as a single point. The values are
        expected estimated region or region dict on different bindings.

    var_dom: dict
        Map var to iteration domain range.

    mode: str
        Specify "lowerbound", "upperbound" or else use strict bound estimation.

    predicate: PrimExpr
        Extra predicate, defaults to True.
    """
    if predicate is None:
        predicate = tvm.tirx.IntImm("bool", 1)
    region = []
    expect = []
    for k, v in expect_region.items():
        if not isinstance(k, tuple | list):
            k = (k, k + 1)
        region.append(tvm.ir.Range.from_min_extent(k[0], Analyzer().simplify(k[1] - k[0])))
        expect.append(v)
    if mode == "lowerbound":
        result = tvm.arith.estimate_region_lower_bound(
            region=region, var_dom=var_dom, predicate=predicate
        )
    elif mode == "upperbound":
        result = tvm.arith.estimate_region_upper_bound(
            region=region, var_dom=var_dom, predicate=predicate
        )
    else:
        result = tvm.arith.estimate_region_strict_bound(
            region=region, var_dom=var_dom, predicate=predicate
        )
    if result is None:
        assert all([_ is None for _ in expect])
        return
    assert len(result) == len(expect)
    for intset, expect_desc in zip(result, expect):
        if isinstance(expect_desc, dict):
            # check range on different free var bindings
            for binding in expect_desc:
                analyzer = Analyzer()
                for k, v in binding:
                    analyzer.bind(k, v)
                expect_begin, expect_end = expect_desc[binding]
                result_begin = analyzer.simplify(intset.min_value, 3)
                result_end = analyzer.simplify(intset.max_value + 1, 3)
                assert analyzer.can_prove_equal(result_begin - expect_begin, 0), (
                    f"{result_begin} vs {expect_begin}"
                )
                assert analyzer.can_prove_equal(result_end - expect_end, 0), (
                    f"{result_end} vs {expect_end}"
                )
        # ... listing ends at source line 197; lines 198-208 are not shown
```
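The `mode` argument selects `estimate_region_lower_bound`, `estimate_region_upper_bound`, or `estimate_region_strict_bound`. One way to read the difference, assumed here rather than taken from the source: an upper bound must cover every index the region touches, a lower bound may contain only touched indices, and a strict bound must be exact. The brute-force sketch below enumerates a small 1-D iteration domain to make that concrete. `touched_points`, `hull`, `is_exact_interval`, and `fused_index` are hypothetical helpers; TVM's estimators reason symbolically rather than by enumeration.

```python
from itertools import product


def touched_points(region_fn, var_dom, predicate=lambda env: True):
    """Collect every integer index a 1-D region touches over an iteration domain.

    region_fn maps an environment {var_name: value} to a half-open (begin, end)
    pair; var_dom maps each var name to (min, extent). Both are stand-ins for the
    PrimExpr region and Range domains the real helper uses.
    """
    names = list(var_dom)
    axes = [range(lo, lo + extent) for lo, extent in (var_dom[n] for n in names)]
    points = set()
    for values in product(*axes):
        env = dict(zip(names, values))
        if predicate(env):  # iterations failing the predicate touch nothing
            begin, end = region_fn(env)
            points.update(range(begin, end))
    return points


def hull(points):
    """Smallest half-open interval containing every point (points must be non-empty).

    Under the reading above, this is always a valid upper bound.
    """
    return (min(points), max(points) + 1)


def is_exact_interval(points):
    """True when the touched set has no gaps, so its hull is also exact."""
    begin, end = hull(points)
    return len(points) == end - begin


def fused_index(env):
    """Hypothetical access pattern: one element at i * 4 + j."""
    index = env["i"] * 4 + env["j"]
    return (index, index + 1)


dense = touched_points(fused_index, {"i": (0, 4), "j": (0, 4)})
print(hull(dense), is_exact_interval(dense))  # (0, 16) True

sparse = touched_points(fused_index, {"i": (0, 4), "j": (0, 3)})
print(hull(sparse), is_exact_interval(sparse))  # (0, 15) False
```

With `j` ranging over 4 values, the fused index `i * 4 + j` covers `[0, 16)` with no gaps, so all three bounds could coincide. Shrinking `j` to 3 values leaves holes at 3, 7, and 11: the hull `[0, 15)` is still a valid upper bound but is no longer exact, which is the kind of case where a strict estimate would be expected to fail and return `None`.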

Calls (8)

simplify · method · 0.95
bind · method · 0.95
can_prove_equal · method · 0.95
Analyzer · class · 0.90
from_min_extent · method · 0.80
all · function · 0.50
items · method · 0.45
append · method · 0.45

Tested by

no test coverage detected
