| 328 | |
| 329 | |
def test_tuple_get_item_nested():
    """Nested ``TupleGetItem`` is ill-formed; an intermediate binding fixes it.

    The first half builds a function whose binding indexes directly into the
    result of another ``TupleGetItem`` and expects the well-formedness check
    to reject it. The second half binds the inner ``TupleGetItem`` to a
    variable first, normalizes the module, and expects the check to pass.
    """
    # Error: The tuple value in tuple get item must be a leaf expression
    nested_tup = rx.Var(
        "t", rx.TupleStructInfo([rx.TupleStructInfo([rx.TensorStructInfo([], "int32")])])
    )
    double_idx = rx.TupleGetItem(rx.TupleGetItem(nested_tup, 0), 0)
    ret_var = rx.Var("r", R.Tensor([], "int32"))
    f = rx.Function(
        [nested_tup],
        rx.SeqExpr([rx.BindingBlock([rx.VarBinding(ret_var, double_idx)])], ret_var),
        ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
    )
    f = f.with_attr("global_symbol", "f")
    mod = tvm.IRModule.from_expr(f)
    # Use the same checker as the positive case below so both halves of the
    # test exercise one API.
    assert not rx.analysis.well_formed(mod, check_struct_info=False)

    # okay with an intermediate binding
    first_idx = rx.TupleGetItem(nested_tup, 0)
    idx_var = rx.Var("v", rx.TupleStructInfo([rx.TensorStructInfo([], "int32")]))
    second_idx = rx.TupleGetItem(idx_var, 0)
    new_f = rx.Function(
        [nested_tup],
        rx.SeqExpr(
            [
                rx.BindingBlock(
                    [rx.VarBinding(idx_var, first_idx), rx.VarBinding(ret_var, second_idx)]
                )
            ],
            ret_var,
        ),
        ret_struct_info=R.Tensor(ndim=0, dtype="int32"),
    )
    new_f = new_f.with_attr("global_symbol", "new_f")
    mod = tvm.IRModule.from_expr(new_f)
    # normalize in order to fill in checked type
    normalized = rx.transform.Normalize()(mod)
    # The result must be asserted; a bare call would discard it and the test
    # could never fail on this path.
    assert rx.analysis.well_formed(normalized, check_struct_info=True)
| 367 | |
| 368 | |
| 369 | def test_complex_seq_body(): |