Rewrite rules are applied until convergence. In this test, both the `RewriteAdd` and `RewriteMultiply` patterns must be applied in order to produce the expected output, and the `RewriteMultiply` pattern only matches the expression produced by the `RewriteAdd` pass.
()
| 326 | |
| 327 | |
def test_recursive_rewrite_rules():
    """Rewrite rules are applied until convergence.

    In this test, both the `RewriteAdd` and `RewriteMultiply` patterns
    must be applied in order to produce the expected output.  The
    `RewriteMultiply` pattern only matches the expression produced by
    the `RewriteAdd` pass, so the combined rewriter must apply
    `RewriteMultiply` to the output of an earlier rewrite rather than
    only to the original input.

    """

    # Step 1: rewrite `A + A` into `A * 2.0`.  This introduces the
    # tensor-times-scalar expression that `RewriteMultiply` looks for.
    @R.rewriter
    class RewriteAdd:
        @R.function
        def pattern(A: R.Tensor([16], "float32")):
            return A + A

        @R.function
        def replacement(A: R.Tensor([16], "float32")):
            return A * R.const(2.0, "float32")

    # Step 2: rewrite a [16]-tensor multiplied by a scalar (shape `[]`)
    # into a call to the packed function "my_optimized_mul_impl".
    # `before` contains no multiplication, so this rule can only fire on
    # the output of `RewriteAdd`.
    @R.rewriter
    class RewriteMultiply:
        @R.function
        def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
            C = A * B
            return C

        @R.function
        def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
            C = R.call_pure_packed(
                "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32")
            )
            return C

    # Input: contains only the addition matched by `RewriteAdd`.
    @R.function(private=True)
    def before(A: R.Tensor([16], "float32")):
        B = A + A
        return B

    # Expected result of `RewriteAdd` followed by `RewriteMultiply`: the
    # scalar constant introduced by `RewriteAdd` is passed as the `B`
    # argument of the packed call.
    @R.function(private=True)
    def expected(A: R.Tensor([16], "float32")):
        B = R.call_pure_packed(
            "my_optimized_mul_impl",
            A,
            R.const(2.0, "float32"),
            sinfo_args=R.Tensor([16], "float32"),
        )
        return B

    # Combine both rules into a single rewriter and apply it to `before`;
    # the result should match `expected` only if both rules fired.
    rewriter = RewriteAdd | RewriteMultiply

    after = rewriter(before)
    tvm.ir.assert_structural_equal(expected, after)
| 381 | |
| 382 | |
| 383 | def test_rewrite_of_arbitrary_dtype(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…