Run get_valid_counts through legalization and compare with the numpy reference.
()
| 573 | |
@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
def test_get_valid_counts_e2e():
    """Legalize ``R.vision.get_valid_counts``, build it for LLVM and check all outputs.

    The Relax module is lowered with ``LegalizeOps``, compiled for the ``llvm``
    target and executed on the CPU through the Relax VM. Each of its three
    results (valid count, output data, output indices) is then compared against
    ``tvm.topi.testing.get_valid_counts_python`` evaluated on the same input
    with the same attributes.
    """

    @tvm.script.ir_module
    class ValidCountsModule:
        @R.function
        def main(
            data: R.Tensor((2, 5, 6), "float32"),
        ) -> R.Tuple(
            R.Tensor((2,), "int32"),
            R.Tensor((2, 5, 6), "float32"),
            R.Tensor((2, 5), "int32"),
        ):
            return R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)

    # Two batches of five rows with six columns each. With id_index=0 and
    # score_index=1, column 0 is read as the id and column 1 as the score.
    # The rows mix scores above and below the 0.5 threshold and include
    # rows whose id is -1.
    boxes = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
                [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
                [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
                [1.0, 0.10, 4.0, 4.0, 5.0, 5.0],
            ],
            [
                [0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
                [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
                [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
                [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0],
            ],
        ],
        dtype="float32",
    )

    # Reference results from the NumPy implementation, using the same attributes
    # as the Relax call above.
    valid_count_ref, data_ref, indices_ref = tvm.topi.testing.get_valid_counts_python(
        boxes, score_threshold=0.5, id_index=0, score_index=1
    )

    cpu = tvm.cpu()
    lowered = LegalizeOps()(ValidCountsModule)
    executable = tvm.compile(lowered, target="llvm")
    machine = relax.VirtualMachine(executable, cpu)
    outputs = machine["main"](tvm.runtime.tensor(boxes, cpu))

    # Reference order matches the R.Tuple return: valid count, data, indices.
    for position, reference in enumerate((valid_count_ref, data_ref, indices_ref)):
        tvm.testing.assert_allclose(outputs[position].numpy(), reference)
| 621 | |
| 622 | |
| 623 | def _prepare_nms_inputs(raw_data: np.ndarray): |
nothing calls this directly
no test coverage detected
searching dependent graphs…