()
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_rocm(), reason="need rocm")
def test_batch_matmul():
    """Exercise the hipBLAS ``batch_matmul`` extern over several shape/dtype cases.

    The test is skipped (not silently passed) when the
    ``tvm.contrib.hipblas.batch_matmul`` packed function is not registered.
    That presumably means TVM was built without hipBLAS support.
    """
    if not tvm.get_global_func("tvm.contrib.hipblas.batch_matmul", True):
        # A bare `return` would be counted by pytest as a pass; report a skip instead.
        pytest.skip("skip because extern function is not available")

    # NOTE(review): argument order presumably is
    # (a_shape, b_shape, out_shape, in_dtype, out_dtype); confirm against
    # verify_batch_matmul.
    a_shape = (16, 1024, 128)
    out_shape = (16, 1024, 236)
    full_b_shape = (16, 128, 236)
    # B with a batch dimension of 1 -- presumably exercises broadcasting of B.
    bcast_b_shape = (1, 128, 236)

    # Each dtype pair runs against the full-batch B first, then the
    # batch-1 B, matching the original call order.
    cases = (
        ("float", "float", {}),
        ("float16", "float", {}),
        # Looser tolerance, presumably because results are stored in float16.
        ("float16", "float16", {"rtol": 1e-2}),
    )
    for in_dtype, out_dtype, extra_kwargs in cases:
        for b_shape in (full_b_shape, bcast_b_shape):
            verify_batch_matmul(a_shape, b_shape, out_shape, in_dtype, out_dtype, **extra_kwargs)

    verify_batch_matmul(a_shape, full_b_shape, out_shape, "int8", "int32")
| 109 | |
| 110 | |
| 111 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected
searching dependent graphs…