MCP · Copy · Index your code
hub / github.com/apache/tvm / test_global_info_vdevice

Function test_global_info_vdevice

tests/python/relax/test_tvmscript_parser.py:315–362  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

313
314
def test_global_info_vdevice():
    """Round-trip a module whose ``vdevice`` global infos are declared in TVMScript.

    The TVMScript module below declares a module attribute, four virtual
    devices under the ``"vdevice"`` global-info key, a private TIR function,
    and a Relax function that invokes it through ``R.call_tir``. An equivalent
    module is assembled by hand with ``relax.BlockBuilder`` (plus the same
    global infos and attribute), and the two are compared with ``_check``.
    """

    @I.ir_module(s_tir=True)
    class TestModule:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice({"kind": "cuda", "arch": "sm_80"}, 0),
                    I.vdevice("metal", 0, "global"),
                ]
            }
        )

        @T.prim_func(private=True, s_tir=True)
        def tir_func(
            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
        ):
            T.func_attr({"tirx.noalias": True})
            for i, j in T.grid(T.int64(128), T.int64(128)):
                with T.sblock():
                    vi, vj = T.axis.remap("SS", [i, j])
                    y[vi, vj] = x[vi, vj] + 1.0

        @R.function
        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"):
            cls = TestModule
            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), dtype="float32"))
            return gv0

    # Hand-build the expected module: a Relax function "foo" whose body is a
    # te-generated PrimFunc named "tir_func" computing input + 1.
    inp = relax.Var("x", R.Tensor((128, 128), "float32"))
    builder = relax.BlockBuilder()
    with builder.function("foo", (inp,)):
        result = builder.emit_te(lambda t: t + 1, inp, primfunc_name_hint="tir_func")
        builder.emit_func_output(result)

    # The same virtual devices, in the same order, as the I.vdevice(...) list above.
    expected_vdevices = [
        VDevice("llvm"),
        VDevice("cuda", 0),
        VDevice({"kind": "cuda", "arch": "sm_80"}, 0),
        VDevice("metal", 0, "global"),
    ]
    expected = builder.get()
    expected.update_global_info("vdevice", expected_vdevices)
    # with_attr returns a new module, so rebind the name.
    expected = expected.with_attr("attr", 10)

    _check(TestModule, expected)
363
364
365def test_relax_tensor_op():

Callers

nothing calls this directly

Calls 9

function · Method · 0.95
emit_te · Method · 0.95
emit_func_output · Method · 0.95
get · Method · 0.95
VDevice · Class · 0.90
Tensor · Method · 0.80
update_global_info · Method · 0.80
_check · Function · 0.70
with_attr · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…