```python
def wrap_nested(expr: rx.Expr, name: str) -> Tensor | Sequence[Tensor]:
    """Wrap the given relax.Expr, emit it using the current BlockBuilder,
    and automatically handle nested cases if the expr represents a Tuple.

    Parameters
    ----------
    expr : relax.Expr
        The Expr to be wrapped.

    name : str
        Name hint.

    Returns
    -------
    result : Union[Tensor, Tuple[Tensor, ...]]
        A Tensor if ``expr`` has TensorStructInfo, or a (possibly nested)
        tuple of Tensors if it has TupleStructInfo.

    Raises
    ------
    TypeError
        If the struct info of ``expr`` is neither a tensor nor a tuple.
    """
    # Bind the expression to a named variable unless it is already a dataflow var.
    if not isinstance(expr, rx.DataflowVar):
        expr = BlockBuilder.current().emit(expr, name)
    # A single tensor: wrap the bound variable directly.
    if isinstance(expr.struct_info_, TensorStructInfo):
        return Tensor(_expr=expr)
    # A tuple: project each field and wrap it recursively, naming fields "<name>.<i>".
    if isinstance(expr.struct_info_, TupleStructInfo):
        return tuple(
            wrap_nested(  # type: ignore
                rx.TupleGetItem(expr, i),
                name=f"{name}.{i}",
            )
            for i in range(len(expr.struct_info_.fields))
        )
    raise TypeError(f"Unsupported return type: {expr.struct_info_}")


def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]):
    ...  # body not included in this excerpt
```
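The recursion in `wrap_nested` is easiest to see when you run it. TVM's `BlockBuilder`, `rx.TupleGetItem` and struct-info classes require a TVM build, so the minimal sketch below replaces them with hypothetical stand-ins: `ToyBuilder`, `GetItem`, `TensorInfo`, `TupleInfo` and `WrappedTensor`. None of these are TVM APIs. The sketch reproduces only the control flow:

- bind the expression unless it is already a variable;
- wrap tensors;
- recurse into tuple fields, using `<name>.<i>` as the name hint.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Tuple, Union


# Hypothetical stand-ins for relax struct info (NOT TVM classes).
@dataclass
class TensorInfo:
    shape: Tuple[int, ...]


@dataclass
class TupleInfo:
    fields: List[Union[TensorInfo, TupleInfo]]


@dataclass
class Var:
    """A named binding produced by the toy builder (plays the role of a DataflowVar)."""
    name: str
    info: Union[TensorInfo, TupleInfo]


@dataclass
class Expr:
    """An unbound expression that only carries its struct info."""
    info: Union[TensorInfo, TupleInfo]


@dataclass
class GetItem:
    """Hypothetical analogue of rx.TupleGetItem: projects one field of a tuple var."""
    tuple_value: Var
    index: int

    @property
    def info(self):
        return self.tuple_value.info.fields[self.index]


class ToyBuilder:
    """Records every emitted binding, standing in for BlockBuilder.emit."""

    def __init__(self) -> None:
        self.bindings: List[Var] = []

    def emit(self, expr, name: str) -> Var:
        var = Var(name, expr.info)
        self.bindings.append(var)
        return var


@dataclass
class WrappedTensor:
    """Stand-in for nn.Tensor."""
    var: Var


def wrap_nested(bb: ToyBuilder, expr, name: str):
    # Bind the expression to a named variable unless it already is one.
    if not isinstance(expr, Var):
        expr = bb.emit(expr, name)
    if isinstance(expr.info, TensorInfo):
        return WrappedTensor(expr)
    if isinstance(expr.info, TupleInfo):
        # Recurse into each field; field names become "<name>.<i>".
        return tuple(
            wrap_nested(bb, GetItem(expr, i), f"{name}.{i}")
            for i in range(len(expr.info.fields))
        )
    raise TypeError(f"Unsupported return type: {expr.info}")


# Usage: a tuple whose second field is itself a tuple of two tensors.
bb = ToyBuilder()
nested = Expr(TupleInfo([TensorInfo((2, 3)), TupleInfo([TensorInfo((4,)), TensorInfo((5,))])]))
result = wrap_nested(bb, nested, "out")
print([v.name for v in bb.bindings])
# ['out', 'out.0', 'out.1', 'out.1.0', 'out.1.1']
```

The sketch follows the same control flow as the original. Each tuple value (`out`, `out.1`) is bound before its fields are projected, so it appears in the binding list next to the leaf tensors. The returned structure mirrors the input nesting, with `WrappedTensor`s at the leaves.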