```python
def compute(shape, fcompute, name="compute", tag="", attrs=None, varargs_names=None):
    """Construct a new tensor by computing over the shape domain.

    The compute rule is result[axis] = fcompute(axis)

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    fcompute: lambda function of indices -> value
        Specifies the input source expression

    name: str, optional
        The name hint of the tensor

    tag: str, optional
        Additional tag information about the compute.

    attrs: dict, optional
        The additional auxiliary attributes about the compute.

    varargs_names: list, optional
        The names to use for each of the varargs. If not supplied, the varargs
        will be called i0, i1, ...

    Returns
    -------
    tensor: Tensor
        The created tensor
    """
    if _tag.TagScope.get_current() is not None:
        if tag != "":
            raise ValueError("nested tag is not allowed for now")
        tag = _tag.TagScope.get_current().tag
    shape = (shape,) if isinstance(shape, tvm.tirx.PrimExpr) else shape
    # Cast float extents (e.g. produced by true division in Python 3) to int
    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
    out_ndim = len(shape)

    argspec = inspect.getfullargspec(fcompute)
    if len(argspec.args) == 0 and argspec.varargs is None:
        arg_names = [f"i{i}" for i in range(out_ndim)]
    elif argspec.varargs is not None:
        # if there is a varargs, it takes the remaining dimensions of out_ndim
        num_remaining_args = out_ndim - len(argspec.args)
        if varargs_names is not None:
            if len(varargs_names) != num_remaining_args:
                raise RuntimeError(
                    f"Number of varargs ({num_remaining_args}) does not match number "
                    f"of varargs_names ({len(varargs_names)})"
                )
            arg_names = argspec.args + varargs_names
        else:
            arg_names = argspec.args + [f"i{i}" for i in range(out_ndim - len(argspec.args))]
    else:
        arg_names = argspec.args
        # if there are fewer args than out dimensions, the remaining dimensions
```
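The docstring's rule `result[axis] = fcompute(axis)` says that each output element is defined by calling `fcompute` on that element's index tuple. TVM itself builds a symbolic expression rather than evaluating anything, so the following is only a minimal eager sketch in plain Python. The helper `eager_compute` is hypothetical and not part of TVM. It also mimics the shape normalization above, under the assumption that a plain `int`/`float` stands in for a single `PrimExpr` extent: a scalar shape becomes a 1-tuple and float extents are cast to `int`.

```python
import itertools
from typing import Any, Callable, Dict, Sequence, Tuple, Union

Shape = Union[int, float, Sequence[Union[int, float]]]


def eager_compute(shape: Shape, fcompute: Callable[..., Any]) -> Dict[Tuple[int, ...], Any]:
    """Hypothetical eager analogue of compute(): result[idx] = fcompute(*idx).

    Returns a dict keyed by index tuple instead of a symbolic tensor.
    """
    # Assumption: a bare number plays the role of a single PrimExpr extent.
    if isinstance(shape, (int, float)):
        shape = (shape,)
    # Float extents (e.g. from true division) are cast to int, as in compute().
    shape = tuple(int(s) if isinstance(s, float) else s for s in shape)
    # Visit every index tuple of the shape domain in row-major order.
    return {idx: fcompute(*idx) for idx in itertools.product(*(range(s) for s in shape))}


# Usage: a 2x3 "tensor" whose element (i, j) is i * 10 + j.
table = eager_compute((2, 3.0), lambda i, j: i * 10 + j)
print(table[(1, 2)])  # 12
```

The dict stands in for a tensor only to keep the sketch dependency-free. In TVM the same lambda would instead receive symbolic iteration variables and return an expression.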
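The argument-naming branches in `compute` can be exercised without TVM. This sketch reimplements only the three branches visible in the excerpt, using `inspect.getfullargspec` as the function does. The helper name `derive_arg_names` is hypothetical, and whatever handling follows the truncated final comment is deliberately not reproduced.

```python
import inspect
from typing import Callable, List, Optional, Sequence


def derive_arg_names(
    fcompute: Callable[..., object],
    out_ndim: int,
    varargs_names: Optional[Sequence[str]] = None,
) -> List[str]:
    """Hypothetical helper mirroring the naming branches visible in compute()."""
    argspec = inspect.getfullargspec(fcompute)
    if len(argspec.args) == 0 and argspec.varargs is None:
        # No parameters at all: generate i0, i1, ..., one per output dimension.
        return [f"i{i}" for i in range(out_ndim)]
    if argspec.varargs is not None:
        # *args absorbs the dimensions not covered by named parameters.
        num_remaining_args = out_ndim - len(argspec.args)
        if varargs_names is not None:
            if len(varargs_names) != num_remaining_args:
                raise RuntimeError(
                    f"Number of varargs ({num_remaining_args}) does not match number "
                    f"of varargs_names ({len(varargs_names)})"
                )
            return list(argspec.args) + list(varargs_names)
        return list(argspec.args) + [f"i{i}" for i in range(num_remaining_args)]
    # Only named parameters: their names are used as-is.
    return list(argspec.args)


print(derive_arg_names(lambda x, y: x + y, 2))           # ['x', 'y']
print(derive_arg_names(lambda n, *rest: n, 3))           # ['n', 'i0', 'i1']
print(derive_arg_names(lambda *idx: 0, 2, ["h", "w"]))   # ['h', 'w']
```

The generated default names start at `i0` because they come from `range(...)`. Supplying `varargs_names` of the wrong length raises the same `RuntimeError` as the original branch.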