Test jit method with force_recompile=True.
()
| 229 | |
| 230 | |
def test_executable_jit_force_recompile():
    """Check that ``Executable.jit(force_recompile=True)`` yields a new module.

    Two plain ``jit()`` calls are expected to return the identical module
    object. A call with ``force_recompile=True`` is expected to return a
    different object whose ``add`` function still writes ``a + b`` into ``c``.
    """
    lib = tvm.tirx.build(MyModule, target=tvm.target.Target("c"))
    exe = Executable(lib)

    # Without forcing, a repeated jit() must hand back the very same module.
    first = exe.jit()
    assert exe.jit() is first

    # Forcing a recompile must produce a distinct module object.
    recompiled = exe.jit(force_recompile=True)
    assert recompiled is not first

    # Exercise the recompiled module: 1.0 + 2.0 should land in the output buffer.
    length = 10
    lhs, rhs, out = (
        tvm.runtime.tensor(np.full(length, fill, dtype="float32"))
        for fill in (1.0, 2.0, 0.0)
    )
    recompiled["add"](lhs, rhs, out)

    expected = np.full(length, 3.0, dtype="float32")
    tvm.testing.assert_allclose(out.numpy(), expected)
| 260 | |
| 261 | |
| 262 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected
searching dependent graphs…