| 404 | |
| 405 | |
def test_inline_prim_func():
    """A PrimFunc bound inline inside a Relax function must be rejected.

    Relax IR disallows inline ``PrimFunc`` values, so the module built here is
    expected to fail ``check_well_formed`` (struct-info checking is disabled).
    """

    def make_inline_prim_func():
        # Build a fresh trivial PrimFunc (no params, body ``Evaluate(0)``) on
        # every call so each use site holds its own distinct object.
        return tirx.PrimFunc([], tirx.Evaluate(0))

    x = rx.Var("x", R.Tensor([], "int32"))
    y = rx.Var("y", R.Tensor([], "int32"))

    # x = <inline PrimFunc>  -- binding a PrimFunc value directly is the
    # disallowed construct under test.
    bind_x = rx.VarBinding(var=x, value=make_inline_prim_func())

    # y = call_tir(GlobalVar0, (x, <inline PrimFunc>), [])
    call_tir_op = tvm.ir.Op.get("relax.call_tir")
    call_args = [
        rx.GlobalVar("GlobalVar0"),
        rx.Tuple([x, make_inline_prim_func()]),
        rx.ShapeExpr([]),
    ]
    bind_y = rx.VarBinding(var=y, value=rx.Call(op=call_tir_op, args=call_args))

    body = rx.SeqExpr([rx.BindingBlock([bind_x, bind_y])], y)
    ret_sinfo = R.Tensor(ndim=0, dtype="int32")
    func = rx.Function([], body, ret_sinfo).with_attr("global_symbol", "foo")

    # NOTE(review): "GlobalVar0" is never added to the module. Depending on the
    # checker, that alone may make the module ill-formed, so this assertion may
    # not isolate the inline-PrimFunc rule -- confirm.
    mod = tvm.IRModule.from_expr(func)
    assert not rx.analysis.check_well_formed(mod, check_struct_info=False)
| 440 | |
| 441 | |
| 442 | def test_ANF(): |