(TestClass, expected=None)
| 59 | |
| 60 | |
def verify(TestClass, expected=None):
    """Import a TensorFlow function into Relax and check the result.

    The conversion always runs. The resulting module is compared structurally
    against ``expected`` when one is given. The end-to-end numerical check
    (run TF and the compiled Relax module on the same random inputs and
    compare their outputs) only runs when the ``CI_ENV_NIGHTLY`` environment
    variable is set.

    Parameters
    ----------
    TestClass : type or concrete function
        Either a class whose instances expose ``func.get_concrete_function()``
        (the class is instantiated with no arguments), or an already-obtained
        concrete function, which is used as-is.
    expected : tvm.IRModule, optional
        Reference module passed to ``tvm.ir.assert_structural_equal``.
        The structural comparison is skipped when this is ``None``.

    Raises
    ------
    AssertionError
        If the module differs structurally from ``expected``, if TF and TVM
        return a different number of outputs, or if any output pair differs
        beyond ``rtol=1e-5, atol=1e-5``.
    """
    if isinstance(TestClass, type):
        cf = TestClass().func.get_concrete_function()
    else:
        cf = TestClass
    mod = _get_mod_from_cfunc(cf)

    # Compare against None explicitly: "no reference given" is the only case
    # to skip, regardless of how the module object evaluates as a boolean.
    if expected is not None:
        tvm.ir.assert_structural_equal(mod, expected)

    # Run E2E test only on nightly
    if "CI_ENV_NIGHTLY" not in os.environ:
        return

    # Build one random input per parameter of "main", using the static shape
    # and dtype from the parameter's struct info. TF and TVM receive the
    # same data.
    tf_inputs = []
    tvm_inputs = []
    for arg in mod["main"].params:
        shape = tuple(shape_val.value for shape_val in arg.struct_info.shape.values)
        data = np.random.uniform(0, 1, size=shape).astype(arg.struct_info.dtype)
        tvm_inputs.append(data)
        tf_inputs.append(tf.constant(data))

    # TF Run
    tf_output = cf(*tf_inputs)

    # TVM Run: compile for the "c" target and execute on the CPU via the VM.
    tgt = tvm.target.Target("c")
    ex = tvm.compile(mod, tgt)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    vm.set_input("main", *tvm_inputs)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")

    def _assert_close(tf_out, tvm_out):
        # Compare a single TF tensor against a single TVM array.
        np.testing.assert_allclose(tf_out.numpy(), tvm_out.numpy(), rtol=1e-5, atol=1e-5)

    # A concrete function may return multiple outputs as a tuple or a list.
    if isinstance(tf_output, (tuple, list)):
        # zip() stops at the shorter sequence, so a missing or extra output
        # would otherwise go unnoticed.
        assert len(tf_output) == len(tvm_output), (
            f"output count mismatch: TF returned {len(tf_output)}, "
            f"TVM returned {len(tvm_output)}"
        )
        for tf_out, tvm_out in zip(tf_output, tvm_output):
            _assert_close(tf_out, tvm_out)
    else:
        _assert_close(tf_output, tvm_output)
| 100 | |
| 101 | |
| 102 | def _verify_random_with_inputs(cfunc, inputs): |
no test coverage detected
searching dependent graphs…