Repository: github.com/apache/tvm

Function test_get_valid_counts_e2e

tests/python/relax/test_op_vision.py, lines 575–620

Run get_valid_counts through legalization and compare with the numpy reference.
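The test compares the compiled op against `tvm.topi.testing.get_valid_counts_python`. The code below is a plain NumPy sketch of the behaviour the test appears to exercise. It is written under stated assumptions and is not copied from TVM. The assumed rule is that a box is valid when its score is strictly greater than `score_threshold` and, when `id_index >= 0`, its class id is non-negative. Valid boxes are packed to the front of each batch in their original order, and every remaining slot in the output data and indices is filled with -1. The name `get_valid_counts_sketch` is hypothetical.

```python
import numpy as np


def get_valid_counts_sketch(data, score_threshold=0.0, id_index=0, score_index=1):
    """Hypothetical NumPy sketch of get_valid_counts (assumed semantics, not TVM's code).

    data: float array of shape (batch, num_anchors, box_len).
    Returns (valid_count, out_data, out_indices) with shapes
    (batch,), (batch, num_anchors, box_len) and (batch, num_anchors).
    """
    batch, num_anchors, _ = data.shape
    valid_count = np.zeros(batch, dtype="int32")
    # Slots that never receive a valid box keep the assumed padding value -1.
    out_data = np.full_like(data, -1.0)
    out_indices = np.full((batch, num_anchors), -1, dtype="int32")
    for b in range(batch):
        for j in range(num_anchors):
            score_ok = data[b, j, score_index] > score_threshold
            # Assumption: a negative id_index disables the class-id filter.
            id_ok = id_index < 0 or data[b, j, id_index] >= 0
            if score_ok and id_ok:
                k = valid_count[b]
                out_data[b, k] = data[b, j]  # pack the valid box to the front
                out_indices[b, k] = j  # remember its original anchor position
                valid_count[b] += 1
    return valid_count, out_data, out_indices


# Same input as data_np in the test: [class_id, score, x1, y1, x2, y2] per row.
data = np.array(
    [
        [[0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
         [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
         [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
         [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
         [1.0, 0.10, 4.0, 4.0, 5.0, 5.0]],
        [[0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
         [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
         [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
         [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
         [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0]],
    ],
    dtype="float32",
)
count, out, idx = get_valid_counts_sketch(data, score_threshold=0.5, id_index=0, score_index=1)
print(count.tolist())  # [2, 3]
print(idx.tolist())  # [[0, 3, -1, -1, -1], [0, 1, 3, -1, -1]]
```

A plain loop is used rather than a vectorized mask so that the packing order, and the index recorded for each kept box, is easy to follow.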

Signature: test_get_valid_counts_e2e() (takes no arguments)


```python
@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
def test_get_valid_counts_e2e():
    """Run get_valid_counts through legalization and compare with the numpy reference."""

    @tvm.script.ir_module
    class GVCModule:
        @R.function
        def main(
            data: R.Tensor((2, 5, 6), "float32"),
        ) -> R.Tuple(
            R.Tensor((2,), "int32"),
            R.Tensor((2, 5, 6), "float32"),
            R.Tensor((2, 5), "int32"),
        ):
            return R.vision.get_valid_counts(data, score_threshold=0.5, id_index=0, score_index=1)

    data_np = np.array(
        [
            [
                [0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
                [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
                [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
                [1.0, 0.10, 4.0, 4.0, 5.0, 5.0],
            ],
            [
                [0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
                [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
                [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
                [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
                [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0],
            ],
        ],
        dtype="float32",
    )
    ref_valid_count, ref_out_data, ref_out_indices = tvm.topi.testing.get_valid_counts_python(
        data_np, score_threshold=0.5, id_index=0, score_index=1
    )

    mod = LegalizeOps()(GVCModule)
    exe = tvm.compile(mod, target="llvm")
    vm = relax.VirtualMachine(exe, tvm.cpu())
    result = vm["main"](tvm.runtime.tensor(data_np, tvm.cpu()))

    tvm.testing.assert_allclose(result[0].numpy(), ref_valid_count)
    tvm.testing.assert_allclose(result[1].numpy(), ref_out_data)
    tvm.testing.assert_allclose(result[2].numpy(), ref_out_indices)
```
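The test input seems designed so that each filter has something to reject. In each batch, one row has a high score but a class id of -1, and other rows have valid ids but scores at or below 0.5. The following NumPy sketch recomputes the per-row masks for that input. It assumes the same validity rule the reference appears to apply: the score must be strictly greater than `score_threshold`, and the class id at `id_index` must be non-negative. That rule is an assumption and was not read from TVM's source.

```python
import numpy as np

# Same input as data_np in the test (batch=2, anchors=5, box_len=6).
data_np = np.array(
    [
        [[0.0, 0.95, 0.0, 0.0, 1.0, 1.0],
         [1.0, 0.30, 0.0, 0.0, 1.0, 1.0],
         [-1.0, 0.90, 0.0, 0.0, 1.0, 1.0],
         [2.0, 0.75, 2.0, 2.0, 3.0, 3.0],
         [1.0, 0.10, 4.0, 4.0, 5.0, 5.0]],
        [[0.0, 0.55, 0.0, 0.0, 1.0, 1.0],
         [1.0, 0.80, 1.0, 1.0, 2.0, 2.0],
         [2.0, 0.40, 2.0, 2.0, 3.0, 3.0],
         [3.0, 0.60, 3.0, 3.0, 4.0, 4.0],
         [-1.0, 0.95, 5.0, 5.0, 6.0, 6.0]],
    ],
    dtype="float32",
)
score_threshold, id_index, score_index = 0.5, 0, 1

# Assumed rule: score strictly above the threshold AND non-negative class id.
score_mask = data_np[:, :, score_index] > score_threshold
id_mask = data_np[:, :, id_index] >= 0
valid_mask = score_mask & id_mask

# Rows rejected by exactly one of the two filters, as (batch, anchor) pairs.
id_only_rejects = score_mask & ~id_mask
score_only_rejects = ~score_mask & id_mask

print(valid_mask.sum(axis=1).tolist())  # [2, 3]
print(np.argwhere(id_only_rejects).tolist())  # [[0, 2], [1, 4]]
print(np.argwhere(score_only_rejects).tolist())  # [[0, 1], [0, 4], [1, 2]]
```

Under that rule, the expected valid counts are 2 and 3. No row fails both filters, so removing either check from the implementation would change at least one count.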

Callers

Nothing calls this function directly; pytest collects and runs it as a test.

Calls (4)

LegalizeOps · function · 0.90
numpy · method · 0.80
compile · method · 0.45
cpu · method · 0.45

Tested by

No test coverage detected (the function is itself a test).
