Repository: github.com/apache/tvm

Function test_save_load_float8

tests/python/contrib/test_tvmjs.py:45–60
Signature: test_save_load_float8(dtype)


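import tempfile

import numpy as np
import pytest

import tvm
from tvm.contrib import tvmjs


# `dtype` is supplied by the test's parametrization, which is not part of this excerpt.
def test_save_load_float8(dtype):
    # float8 and bfloat16 types are not built into NumPy; take them from the
    # optional ml_dtypes package, skipping the test if it is not installed.
    if "float8" in dtype or "bfloat16" in dtype:
        ml_dtypes = pytest.importorskip("ml_dtypes")
        np_dtype = np.dtype(getattr(ml_dtypes, dtype))
    else:
        np_dtype = np.dtype(dtype)

    arr = np.arange(16, dtype=np_dtype)

    # Write the array to a tensor cache on disk, then load it back onto the CPU.
    with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir:
        tvmjs.dump_tensor_cache({"arr": arr}, temp_dir)
        cache, _ = tvmjs.load_tensor_cache(temp_dir, tvm.cpu())

    # Convert the loaded tensor back to NumPy for comparison.
    after_roundtrip = cache["arr"].numpy()

    # The round trip must preserve every element exactly.
    np.testing.assert_array_equal(arr, after_roundtrip)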
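The test checks two things: that a dtype name resolves to a NumPy dtype (float8 and bfloat16 variants come from the optional `ml_dtypes` package) and that a save/load round trip preserves values exactly. The sketch below reproduces that pattern without TVM so it can run anywhere NumPy is installed. It is a minimal illustration under stated assumptions: `resolve_dtype` and `roundtrip_raw` are hypothetical helpers, and the raw-bytes-plus-JSON layout is a stand-in chosen for illustration, not the format `tvmjs.dump_tensor_cache` actually writes.

```python
import json
import os
import tempfile

import numpy as np


def resolve_dtype(name: str) -> np.dtype:
    """Map a dtype name to a NumPy dtype, following the test's branching.

    float8 and bfloat16 variants are not built into NumPy; they are looked up
    in the optional ml_dtypes package (ImportError if it is missing).
    """
    if "float8" in name or "bfloat16" in name:
        import ml_dtypes  # optional dependency, imported only when needed

        return np.dtype(getattr(ml_dtypes, name))
    return np.dtype(name)


def roundtrip_raw(arr: np.ndarray, directory: str) -> np.ndarray:
    """Hypothetical stand-in for a tensor-cache dump/load pair.

    Writes the array's raw bytes plus a JSON record of its shape and dtype
    name, then reads both back and rebuilds the array. Not TVM's format.
    """
    data_path = os.path.join(directory, "arr.bin")
    meta_path = os.path.join(directory, "meta.json")

    # Save: raw element bytes + the metadata needed to reinterpret them.
    with open(data_path, "wb") as f:
        f.write(arr.tobytes())
    with open(meta_path, "w") as f:
        json.dump({"shape": list(arr.shape), "dtype": arr.dtype.name}, f)

    # Load: resolve the dtype by name, then view the bytes with that dtype.
    with open(meta_path) as f:
        meta = json.load(f)
    with open(data_path, "rb") as f:
        raw = f.read()
    return np.frombuffer(raw, dtype=resolve_dtype(meta["dtype"])).reshape(meta["shape"])


if __name__ == "__main__":
    names = ["int8", "int32", "float16", "float32", "float64"]
    try:
        import ml_dtypes  # noqa: F401

        names.append("float8_e4m3fn")
    except ImportError:
        pass  # like pytest.importorskip: skip the extension dtype if unavailable

    for name in names:
        arr = np.arange(16, dtype=resolve_dtype(name))
        with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir:
            after_roundtrip = roundtrip_raw(arr, temp_dir)
        np.testing.assert_array_equal(arr, after_roundtrip)
        print(name, "ok")
```

Storing the dtype as a name and reinterpreting the raw bytes with `np.frombuffer` lets one load path serve both built-in NumPy dtypes and `ml_dtypes` extension types.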

Callers

nothing calls this directly

Calls (4)

numpy (method) · 0.80
dtype (method) · 0.45
arange (method) · 0.45
cpu (method) · 0.45

Tested by

no test coverage detected
