Check that bool_expr holds given the condition cond for every value of free variables from vranges. For example, ``2x > 4y`` solves to ``x > 2y`` given ``x in (0, 10)`` and ``y in (0, 10)``. Here bool_expr is ``x > 2y``, vranges is ``{x: (0, 10), y: (0, 10)}``, cond is ``2x > 4y``.
(bool_expr, vranges, cond=None)
| 273 | |
| 274 | |
def check_bool_expr_is_true(bool_expr, vranges, cond=None):
    """Assert that ``bool_expr`` is true at every point of ``vranges``
    where ``cond`` (when given) also holds.

    The check is exhaustive. The expression is compiled into a tensor with
    one axis per free variable, and every element of that tensor is
    inspected. For example, ``2x > 4y`` solves to ``x > 2y`` given
    ``x in (0, 10)`` and ``y in (0, 10)``. Here ``bool_expr`` is ``x > 2y``,
    ``vranges`` is ``{x: (0, 10), y: (0, 10)}``, and ``cond`` is
    ``2x > 4y``. The generated check is equivalent to::

        for x in range(10):
            for y in range(10):
                assert !(2x > 4y) || (x > 2y)

    Parameters
    ----------
    bool_expr : tvm.ir.PrimExpr
        Boolean expression to check.
    vranges : Dict[tvm.tirx.expr.Var, tvm.ir.Range]
        Free variables and their ranges. Each extent's ``.value`` is used as
        a tensor dimension, so the extents are expected to be constants.
    cond : tvm.ir.PrimExpr, optional
        Extra condition that must be satisfied. Points where it is false
        are not checked.

    Raises
    ------
    AssertionError
        If some point violates the expression. The message contains the
        simplified expression and the first counterexample found.
    """
    if cond is not None:
        # "cond implies bool_expr" is encoded as "not cond or bool_expr".
        bool_expr = tvm.te.any(tvm.tirx.Not(cond), bool_expr)

    def _shifted_expr(*offsets):
        # Tensor indices start at 0, so each index is shifted by its range
        # minimum before it replaces the matching free variable.
        substitution = {}
        for (var, rng), off in zip(vranges.items(), offsets):
            substitution[var] = off + rng.min
        return tvm.tirx.stmt_functor.substitute(bool_expr, substitution)

    # Materialize the expression over the whole iteration domain.
    shape = [rng.extent.value for _, rng in vranges.items()]
    table = tvm.te.compute(shape, _shifted_expr)
    out = tvm.runtime.empty(table.shape, table.dtype)
    func = tvm.compile(tvm.IRModule.from_expr(tvm.te.create_prim_func([table])))
    func(out)
    result = out.numpy()

    if np.all(result):
        return

    # Report the first failing point (row-major order), converted back from
    # tensor indices to variable values and listed by variable name.
    first_failure = np.argwhere(result == 0)[0]
    bindings = sorted(
        (
            (str(var), idx + rng.min)
            for (var, rng), idx in zip(vranges.items(), first_failure)
        ),
        key=lambda pair: pair[0],
    )
    description = ", ".join(f"{name} = {value!s}" for name, value in bindings)
    analyzer = tvm.arith.Analyzer()
    raise AssertionError(
        f"Expression {analyzer.simplify(bool_expr)}\nis not true on {vranges}\n"
        f"Counterexample: {description}"
    )
| 326 | |
| 327 | |
| 328 | def check_int_constraints_trans_consistency(constraints_trans, vranges=None): |