Repository: github.com/apache/tvm

Function compute

python/tvm/te/operation.py:59–139

Construct a new tensor by computing over the shape domain. The compute rule is result[axis] = fcompute(axis).
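To make the compute rule concrete, here is a minimal, eager pure-Python model of `result[axis] = fcompute(axis)`: every index tuple in the shape domain is visited and the value is stored in a nested list. It is an illustration only, assuming nested lists as storage; the real `compute` builds a symbolic tensor expression (e.g. `te.compute((m, n), lambda i, j: A[i, j] * 2)` with `A` a placeholder tensor) and evaluates nothing. The helper name `eager_compute` is hypothetical.

```python
from typing import Any, Callable, Sequence, Tuple


def eager_compute(shape: Sequence[int], fcompute: Callable[..., Any]) -> Any:
    """Hypothetical eager model of result[axis] = fcompute(axis); not part of TVM."""

    def build(prefix: Tuple[int, ...]) -> Any:
        depth = len(prefix)
        if depth == len(shape):
            # A full index tuple: evaluate the compute rule at this point.
            return fcompute(*prefix)
        # Otherwise recurse over the next axis of the shape domain.
        return [build(prefix + (i,)) for i in range(shape[depth])]

    return build(())


# Each element is fcompute applied to its own (i, j) index.
table = eager_compute((2, 3), lambda i, j: 10 * i + j)  # [[0, 1, 2], [10, 11, 12]]
```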

Signature: compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=None)
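The `tag` parameter interacts with any enclosing tag scope: the source below inherits the scope's tag when `tag` is left empty, and raises `ValueError` when a scope is active and a non-empty `tag` is also passed. A minimal sketch of that rule, assuming a simple class-level stack of scopes; `TagScope` and `resolve_tag` here are hypothetical stand-ins, not TVM's `_tag.TagScope`.

```python
from typing import List, Optional


class TagScope:
    """Hypothetical stand-in for a tag scope; not TVM's implementation."""

    _stack: List["TagScope"] = []  # shared by all instances on purpose

    def __init__(self, tag: str) -> None:
        self.tag = tag

    @classmethod
    def get_current(cls) -> Optional["TagScope"]:
        # The innermost active scope, or None outside any scope.
        return cls._stack[-1] if cls._stack else None

    def __enter__(self) -> "TagScope":
        TagScope._stack.append(self)
        return self

    def __exit__(self, *exc) -> None:
        TagScope._stack.pop()


def resolve_tag(tag: str = "") -> str:
    """Mirror the tag check at the top of compute()."""
    scope = TagScope.get_current()
    if scope is not None:
        if tag != "":
            # An explicit tag inside an active scope is rejected.
            raise ValueError("nested tag is not allowed for now")
        tag = scope.tag
    return tag


with TagScope("elemwise"):
    inherited = resolve_tag()  # "elemwise"
```

Keeping scopes on a stack means the innermost `with` block wins; whether TVM's own class stores scopes the same way is not shown on this page.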


```python
# Excerpt of python/tvm/te/operation.py (lines 59–116 of 59–139).
# Module-level imports (not shown) provide `inspect`, `tvm`, and `_tag`.
def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    fcompute: lambda function of indices-> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    varargs_names: list, optional
        The names to use for each of the varargs. If not supplied, the varargs
        will be called i0, i1, ...

    Returns
    -------
    tensor: Tensor
        The created tensor
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr) else shape
    # for python3
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    out_ndim = len(shape)

    argspec = inspect.getfullargspec(fcompute)
    if len(argspec.args) == 0 and argspec.varargs is None:
        arg_names = [f"i{i}" for i in range(out_ndim)]
    elif argspec.varargs is not None:
        # if there is a varargs, it takes the remaining dimensions of out_ndim
        num_remaining_args = out_ndim - len(argspec.args)
        if varargs_names is not None:
            if len(varargs_names) != num_remaining_args:
                raise RuntimeError(
                    f"Number of varargs ({num_remaining_args}) does not match number "
                    f"of varargs_names ({len(varargs_names)})"
                )
            arg_names = argspec.args + varargs_names
        else:
            arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]
    else:
        arg_names = argspec.args
        # if there are fewer args than out dimensions, the remaining dimensions
        # ... (lines 117–139 are not shown in this excerpt)
```
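The two shape lines in the excerpt accept either a single expression (treated as a 1-D shape) or a sequence, and convert float extents to int. A sketch of that normalization, assuming a hypothetical `Expr` class in place of `tvm.tirx.PrimExpr`; `normalize_shape` is likewise a hypothetical name.

```python
from typing import Any, Tuple


class Expr:
    """Hypothetical stand-in for tvm.tirx.PrimExpr in this sketch."""

    def __init__(self, name: str) -> None:
        self.name = name


def normalize_shape(shape: Any) -> Tuple[Any, ...]:
    # A single expression becomes a 1-tuple, mirroring compute().
    if isinstance(shape, Expr):
        shape = (shape,)
    # Float extents are converted with int(); everything else passes through.
    return tuple(int(s) if isinstance(s, float) else s for s in shape)


n = Expr("n")
one_d = normalize_shape(n)            # (n,)
mixed = normalize_shape((4.0, 3))     # (4, 3)
```

Because `int()` truncates toward zero, a float extent such as 2.5 silently becomes 2. The terse `# for python3` comment presumably refers to Python 3's `/` returning floats, though the excerpt does not say so.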

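The `argspec` branches in the excerpt decide what the index variables are called. The sketch below re-creates only that naming logic with `inspect.getfullargspec`, assuming nothing beyond what the excerpt shows; `resolve_arg_names` is a hypothetical helper, not part of TVM.

```python
import inspect
from typing import Callable, List, Optional


def resolve_arg_names(
    fcompute: Callable, out_ndim: int, varargs_names: Optional[List[str]] = None
) -> List[str]:
    """Hypothetical re-creation of the index-naming branch in compute()."""
    argspec = inspect.getfullargspec(fcompute)
    if len(argspec.args) == 0 and argspec.varargs is None:
        # No named parameters and no *args: name every dimension i0, i1, ...
        return [f"i{i}" for i in range(out_ndim)]
    if argspec.varargs is not None:
        # *args absorbs the dimensions left after the named parameters.
        num_remaining = out_ndim - len(argspec.args)
        if varargs_names is not None:
            if len(varargs_names) != num_remaining:
                raise RuntimeError(
                    f"Number of varargs ({num_remaining}) does not match number "
                    f"of varargs_names ({len(varargs_names)})"
                )
            return argspec.args + list(varargs_names)
        # Generated names for the *args dimensions start at i0.
        return argspec.args + [f"i{i}" for i in range(num_remaining)]
    # Named parameters only: their names are used as-is.
    return list(argspec.args)


auto_names = resolve_arg_names(lambda b, *rest: 0, 3)                 # ["b", "i0", "i1"]
given_names = resolve_arg_names(lambda b, *rest: 0, 3, ["h", "w"])    # ["b", "h", "w"]
```

For example, `lambda b, *rest: ...` over a 3-D shape gets the names `b, i0, i1`, or `b, h, w` with `varargs_names=["h", "w"]`. The sketch stops where the excerpt does; the case of fewer named parameters than dimensions continues in lines not shown.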

Calls (5), with the index's scores:

tuple (function) · 0.85
fcompute (function) · 0.85
get_current (method) · 0.80
output (method) · 0.80
convert (function) · 0.50

