```python
# Imports needed to make this excerpt self-contained.
from tvm import relax as rx
from tvm.relax import ShapeStructInfo, TensorStructInfo
from tvm.relax.frontend.nn._tensor_op import _TensorOp


class Tensor(_TensorOp):
    """A wrapper on top of relax.Expr whose struct_info is a TensorStructInfo, providing more
    convenient access to shape and dtype information. Tensor is always symbolic and not bound to
    any concrete values. Shape and dtype inference is done eagerly upon tensor creation, i.e. when
    operators are applied to tensors, the shape and dtype information is already available.
    """

    _expr: rx.Expr

    def __init__(self, *, _expr: rx.Expr) -> None:
        """Private constructor. Tensor is never supposed to be constructed directly by users."""

        def _check_tensor(expr: rx.Expr) -> None:
            # The expression must carry TensorStructInfo with a known rank, and its shape
            # must itself carry ShapeStructInfo that lists the individual dimension values.
            assert expr.struct_info_ is not None
            assert isinstance(expr.struct_info, TensorStructInfo)
            assert expr.struct_info.ndim != -1
            assert expr.struct_info.shape is not None
            assert expr.struct_info.shape.struct_info_ is not None
            assert isinstance(expr.struct_info.shape.struct_info, ShapeStructInfo)
            assert expr.struct_info.shape.struct_info.values is not None

        _check_tensor(_expr)
        self._expr = _expr

    @staticmethod
    def from_const(data) -> "Tensor":
        """Construct a tensor from NumPy constants."""
        return Tensor(_expr=rx.const(data))

    @staticmethod
    def from_scalar(data: int | float, dtype: str) -> "Tensor":
        """Construct a tensor from a scalar with dtype specified."""
        return Tensor(_expr=rx.const(data, dtype=dtype))

    @staticmethod
    def from_struct_info(struct_info: rx.TensorStructInfo, name: str = "tensor") -> "Tensor":
        """Construct an nn.Tensor from a Relax TensorStructInfo.

        TensorStructInfo is the Relax type-level description of a tensor, carrying its shape
        and dtype without holding actual data. This factory creates an unbound placeholder
        ``nn.Tensor`` that can be used as a symbolic input when tracing an ``nn.Module``.

        Parameters
        ----------
        struct_info : rx.TensorStructInfo
            The struct info describing the tensor's shape and dtype.

        name : str
            Name hint for the underlying Relax variable.

        Returns
        -------
        tensor : Tensor
            A symbolic ``nn.Tensor`` backed by a ``relax.Var`` with the given struct info.
        """
        return Tensor(
            _expr=rx.Var(
                name_hint=name,
                struct_info=struct_info,
            )
        )
```
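The docstring's claim that shape and dtype are inferred eagerly, before any data exists, can be made concrete with a small pure-Python toy that is independent of TVM. Here an operator computes its result's shape and dtype at the moment it is applied. `SymTensor` and its NumPy-style broadcasting rule are hypothetical illustrations, not TVM APIs. The class above instead reads shape and dtype from the Relax `TensorStructInfo` attached to its expression.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# A dimension is either a concrete size or a symbolic name such as "n".
Dim = Union[int, str]


@dataclass(frozen=True)
class SymTensor:
    """Hypothetical toy symbolic tensor: it carries metadata only, never data."""

    shape: Tuple[Dim, ...]
    dtype: str

    def __post_init__(self) -> None:
        # Validate at construction time, similar in spirit to _check_tensor.
        for d in self.shape:
            if isinstance(d, int) and d < 0:
                raise ValueError(f"negative dimension: {d}")

    def __add__(self, other: "SymTensor") -> "SymTensor":
        # The dtype check happens when the operator is applied, not at run time.
        if self.dtype != other.dtype:
            raise TypeError(f"dtype mismatch: {self.dtype} vs {other.dtype}")
        # Right-align both shapes by padding the shorter one with leading 1s.
        rank = max(len(self.shape), len(other.shape))
        a = (1,) * (rank - len(self.shape)) + self.shape
        b = (1,) * (rank - len(other.shape)) + other.shape
        out = []
        for x, y in zip(a, b):
            if x == 1:
                out.append(y)
            elif y == 1 or x == y:
                # A symbolic dim is treated as compatible only with itself or with 1.
                out.append(x)
            else:
                raise ValueError(f"cannot broadcast {self.shape} with {other.shape}")
        # The result's shape and dtype are known as soon as this expression is built.
        return SymTensor(tuple(out), self.dtype)


x = SymTensor(("n", 4), "float32")
y = x + SymTensor((4,), "float32")
print(y.shape, y.dtype)  # ('n', 4) float32
```

Accepting a symbolic dimension only against itself or 1 is a deliberate simplification. It keeps the sketch short, at the cost of rejecting some shape combinations that a real system could resolve with runtime checks.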