hub / github.com/apache/tvm / check_bool_expr_is_true

Function check_bool_expr_is_true

python/tvm/testing/utils.py:275–325

Check that bool_expr holds given the condition cond for every value of free variables from vranges. For example, ``2x > 4y`` solves to ``x > 2y`` given ``x in (0, 10)`` and ``y in (0, 10)``. Here bool_expr is ``x > 2y``, vranges is ``{x: (0, 10), y: (0, 10)}``, cond is ``2x > 4y``.
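The check described above amounts to a brute-force test of the implication "cond implies bool_expr" over every point in the ranges. A minimal pure-Python sketch of that idea follows. It does not use TVM: `brute_force_holds` is a hypothetical helper, the expressions are plain Python callables, and each range is assumed to be a `(min, extent)` pair.

```python
from itertools import product


def brute_force_holds(bool_expr, vranges, cond=None):
    """Return the first counterexample as a dict, or None if there is none.

    Hypothetical stand-in for the TVM version: bool_expr and cond are
    Python callables taking keyword arguments, and vranges maps a
    variable name to an assumed (min, extent) pair.
    """
    names = list(vranges)
    spans = [range(lo, lo + extent) for lo, extent in vranges.values()]
    for values in product(*spans):
        point = dict(zip(names, values))
        # "cond implies bool_expr" is (not cond) or bool_expr, so points
        # where cond is false can never be counterexamples.
        if cond is not None and not cond(**point):
            continue
        if not bool_expr(**point):
            return point
    return None


vranges = {"x": (0, 10), "y": (0, 10)}
# With the condition 2x > 4y, the expression x > 2y holds everywhere.
ok = brute_force_holds(
    lambda x, y: x > 2 * y, vranges, cond=lambda x, y: 2 * x > 4 * y
)
# Without the condition, x > 2y already fails at the first point checked.
bad = brute_force_holds(lambda x, y: x > 2 * y, vranges)
```

Here `ok` is `None` because no point violates the implication, while `bad` is the first violating point in iteration order.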

(bool_expr, vranges, cond=None)

Source

def check_bool_expr_is_true(bool_expr, vranges, cond=None):
    """Check that bool_expr holds given the condition cond
    for every value of free variables from vranges.

    For example, ``2x > 4y`` solves to ``x > 2y`` given ``x in (0, 10)``
    and ``y in (0, 10)``. Here bool_expr is ``x > 2y``,
    vranges is ``{x: (0, 10), y: (0, 10)}``, cond is ``2x > 4y``.
    We create iterations to check::

        for x in range(10):
            for y in range(10):
                assert !(2x > 4y) || (x > 2y)

    Parameters
    ----------
    bool_expr : tvm.ir.PrimExpr
        Boolean expression to check
    vranges: Dict[tvm.tirx.expr.Var, tvm.ir.Range]
        Free variables and their ranges
    cond: tvm.ir.PrimExpr
        extra conditions needs to be satisfied.
    """
    if cond is not None:
        bool_expr = tvm.te.any(tvm.tirx.Not(cond), bool_expr)

    def _run_expr(expr, vranges):
        """Evaluate expr for every value of free variables
        given by vranges and return the tensor of results.
        """

        def _compute_body(*us):
            vmap = {v: u + r.min for (v, r), u in zip(vranges.items(), us)}
            return tvm.tirx.stmt_functor.substitute(expr, vmap)

        A = tvm.te.compute([r.extent.value for v, r in vranges.items()], _compute_body)
        args = [tvm.runtime.empty(A.shape, A.dtype)]
        mod = tvm.compile(tvm.IRModule.from_expr(tvm.te.create_prim_func([A])))
        mod(*args)
        return args[0].numpy()

    res = _run_expr(bool_expr, vranges)
    if not np.all(res):
        indices = list(np.argwhere(res == 0)[0])
        counterex = [(str(v), i + r.min) for (v, r), i in zip(vranges.items(), indices)]
        counterex = sorted(counterex, key=lambda x: x[0])
        counterex = ", ".join([v + " = " + str(i) for v, i in counterex])
        ana = tvm.arith.Analyzer()
        raise AssertionError(
            f"Expression {ana.simplify(bool_expr)}\nis not true on {vranges}\n"
            f"Counterexample: {counterex}"
        )
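The counterexample reporting at the end of the function can be illustrated without TVM. The sketch below assumes the result grid is a NumPy boolean array with one axis per variable, and that each range is a `(min, extent)` pair. `format_counterexample` and the sample ranges are hypothetical names introduced for illustration only.

```python
import numpy as np


def format_counterexample(res, vranges):
    """Return 'name = value' pairs for the first failing grid point.

    Hypothetical helper: res has one axis per entry of vranges, and
    vranges maps a variable name to an assumed (min, extent) pair.
    Returns None when every entry of res is true.
    """
    if np.all(res):
        return None
    # argwhere lists failing indices in row-major order; take the first.
    indices = np.argwhere(res == 0)[0]
    # Shift each grid index by the range minimum to recover the value.
    pairs = [
        (name, int(i) + lo)
        for (name, (lo, _)), i in zip(vranges.items(), indices)
    ]
    # Sort by variable name so the message is stable.
    return ", ".join(f"{name} = {value}" for name, value in sorted(pairs))


# Evaluate x > y on the grid y in [0, 4), x in [2, 5).
vranges = {"y": (0, 4), "x": (2, 3)}
grid = np.indices([extent for _, extent in vranges.values()])
y = grid[0] + vranges["y"][0]
x = grid[1] + vranges["x"][0]
message = format_counterexample(x > y, vranges)
```

The first failing point in row-major order is `y = 2, x = 2`, which the helper reports sorted by name as `x = 2, y = 2`.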

Callers 1

_check_forward · Function · 0.85

Calls 5

simplify · Method · 0.95
_run_expr · Function · 0.85
str · Function · 0.85
items · Method · 0.45
join · Method · 0.45

Tested by

no test coverage detected
