Repository: github.com/apache/tvm

Function test_recursive_rewrite_rules

tests/python/relax/test_dataflow_rewriter.py:328–380

Rewrite rules are applied until convergence. In this test, both the `RewriteAdd` and `RewriteMultiply` patterns must be applied to produce the expected output, and `RewriteMultiply` can only match the expression that `RewriteAdd` produces.
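Why convergence matters can be shown with a toy model that is independent of TVM. The sketch below assumes a hypothetical mini expression language (`Var`, `Const`, `Add`, `Mul`, `Call`) and two hand-written rules named after the ones in the test; it is not how TVM's rewriter is implemented. With the multiply rule deliberately tried first, a single ordered pass stops at the intermediate `A * 2.0`, while repeating passes until the expression stops changing reaches the packed call regardless of rule order.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Union


# Hypothetical toy IR nodes; these are not TVM Relax classes.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Add:
    lhs: "Expr"
    rhs: "Expr"

@dataclass(frozen=True)
class Mul:
    lhs: "Expr"
    rhs: "Expr"

@dataclass(frozen=True)
class Call:
    func: str
    args: tuple


Expr = Union[Var, Const, Add, Mul, Call]
Rule = Callable[[Expr], Optional[Expr]]


def rewrite_add(e: Expr) -> Optional[Expr]:
    # Toy analogue of RewriteAdd: `A + A` becomes `A * 2.0`.
    if isinstance(e, Add) and e.lhs == e.rhs:
        return Mul(e.lhs, Const(2.0))
    return None


def rewrite_multiply(e: Expr) -> Optional[Expr]:
    # Toy analogue of RewriteMultiply; a Const stands in for the scalar `B`.
    if isinstance(e, Mul) and isinstance(e.rhs, Const):
        return Call("my_optimized_mul_impl", (e.lhs, e.rhs))
    return None


def single_pass(e: Expr, rules: Sequence[Rule]) -> Expr:
    # Try each rule once, in order, at the root of the expression.
    for rule in rules:
        new = rule(e)
        if new is not None:
            e = new
    return e


def until_convergence(e: Expr, rules: Sequence[Rule], max_iters: int = 100) -> Expr:
    # Repeat passes until one of them leaves the expression unchanged.
    for _ in range(max_iters):
        new = single_pass(e, rules)
        if new == e:
            return e
        e = new
    raise RuntimeError("rewriting did not converge")


A = Var("A")
before = Add(A, A)
rules = [rewrite_multiply, rewrite_add]  # multiply rule deliberately tried first

print(single_pass(before, rules))        # stops at the intermediate Mul
print(until_convergence(before, rules))  # reaches the Call
```

Iterating to a fixed point removes the dependence on rule order: whichever rule runs first, a later pass picks up the expression that an earlier rule produced.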



import tvm
from tvm.script import relax as R


def test_recursive_rewrite_rules():
    """Rewrite rules are applied until convergence

    In this test, both the `RewriteAdd` and `RewriteMultiply` patterns
    must be applied in order to produce the expected output. However,
    the `RewriteMultiply` pattern relies on the expression produced by
    the `RewriteAdd` pass.

    """

    # Rewrite `A + A` as `A * 2.0`.
    @R.rewriter
    class RewriteAdd:
        @R.function
        def pattern(A: R.Tensor([16], "float32")):
            return A + A

        @R.function
        def replacement(A: R.Tensor([16], "float32")):
            return A * R.const(2.0, "float32")

    # Rewrite a tensor-times-scalar product as a call to a packed function.
    @R.rewriter
    class RewriteMultiply:
        @R.function
        def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
            C = A * B
            return C

        @R.function
        def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")):
            C = R.call_pure_packed(
                "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32")
            )
            return C

    # `A + A` matches only RewriteAdd; its output is what RewriteMultiply matches.
    @R.function(private=True)
    def before(A: R.Tensor([16], "float32")):
        B = A + A
        return B

    @R.function(private=True)
    def expected(A: R.Tensor([16], "float32")):
        B = R.call_pure_packed(
            "my_optimized_mul_impl",
            A,
            R.const(2.0, "float32"),
            sinfo_args=R.Tensor([16], "float32"),
        )
        return B

    # Compose both rewriters into a single rewriter.
    rewriter = RewriteAdd | RewriteMultiply

    after = rewriter(before)
    tvm.ir.assert_structural_equal(expected, after)
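The test composes the two rewriters with `|` and checks the combined result. A minimal sketch of that composition idea follows, using a hypothetical `ToyRewriter` class over nested tuples; it is not TVM's rewriter class and does not reproduce its pattern matching. Each toy rewriter holds a list of rules, `|` concatenates the lists, and calling the rewriter sweeps the expression bottom-up, repeating until a sweep changes nothing.

```python
from typing import Any, Callable, List, Optional

# Toy expressions are nested tuples such as ("add", "A", "A"); this is a
# hypothetical stand-in for Relax IR, used only to illustrate composition.
Rule = Callable[[Any], Optional[Any]]


class ToyRewriter:
    """Hypothetical composable rewriter; not TVM's rewriter class."""

    def __init__(self, rules: List[Rule]):
        self.rules = list(rules)

    def __or__(self, other: "ToyRewriter") -> "ToyRewriter":
        # `a | b` produces one rewriter holding the rules of both operands.
        return ToyRewriter(self.rules + other.rules)

    def _sweep(self, expr: Any) -> Any:
        # Rewrite the operands first (bottom-up), then try each rule here.
        if isinstance(expr, tuple):
            expr = (expr[0],) + tuple(self._sweep(arg) for arg in expr[1:])
        for rule in self.rules:
            new = rule(expr)
            if new is not None:
                return new
        return expr

    def __call__(self, expr: Any, max_iters: int = 100) -> Any:
        # Sweep repeatedly until the expression stops changing.
        for _ in range(max_iters):
            new = self._sweep(expr)
            if new == expr:
                return expr
            expr = new
        raise RuntimeError("rewriting did not converge")


def add_rule(e: Any) -> Optional[Any]:
    # ("add", x, x) -> ("mul", x, 2.0)
    if isinstance(e, tuple) and e[0] == "add" and e[1] == e[2]:
        return ("mul", e[1], 2.0)
    return None


def mul_rule(e: Any) -> Optional[Any]:
    # ("mul", x, <float scalar>) -> ("call", "my_optimized_mul_impl", x, scalar)
    if isinstance(e, tuple) and e[0] == "mul" and isinstance(e[2], float):
        return ("call", "my_optimized_mul_impl", e[1], e[2])
    return None


rewriter = ToyRewriter([add_rule]) | ToyRewriter([mul_rule])
after = rewriter(("add", "A", "A"))
print(after)  # ('call', 'my_optimized_mul_impl', 'A', 2.0)
```

Applied on its own, the toy add rule stops at `("mul", "A", 2.0)` and the toy multiply rule leaves the input untouched, which mirrors why the test needs both rules and repeated application.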