| 304 | return out_tensors |
| 305 | |
| 306 | def _infer_concrete_shape_from_args(self, shape, in_args): |
| 307 | concrete = [] |
| 308 | symbolic_positions = [] |
| 309 | for idx, dim in enumerate(shape): |
| 310 | if isinstance(dim, int | np.integer): |
| 311 | concrete.append(int(dim)) |
| 312 | elif isinstance(dim, tirx.IntImm): |
| 313 | concrete.append(int(dim.value)) |
| 314 | else: |
| 315 | concrete.append(None) |
| 316 | symbolic_positions.append(idx) |
| 317 | |
| 318 | if not symbolic_positions: |
| 319 | return concrete |
| 320 | |
| 321 | candidates = [] |
| 322 | if in_args is not None: |
| 323 | if not isinstance(in_args, list | tuple): |
| 324 | in_args = [in_args] |
| 325 | for obj in in_args: |
| 326 | if hasattr(obj, "shape") and isinstance(obj.shape, tuple | list): |
| 327 | try: |
| 328 | candidates.append(tuple(int(x) for x in obj.shape)) |
| 329 | continue |
| 330 | except (ValueError, TypeError): |
| 331 | # Skip objects with invalid shapes |
| 332 | pass |
| 333 | |
| 334 | target_ndim = len(shape) |
| 335 | for cand in candidates: |
| 336 | if len(cand) == target_ndim: |
| 337 | for pos in symbolic_positions: |
| 338 | concrete[pos] = cand[pos] |
| 339 | if all(x is not None for x in concrete): |
| 340 | return concrete |
| 341 | |
| 342 | raise ValueError( |
| 343 | "Cannot infer concrete output shape from symbolic shape and inputs. " |
| 344 | "Please provide a concrete `out_sinfo` (e.g., a tuple/list of ints) " |
| 345 | "or ensure input tensors carry shapes that determine output extents." |
| 346 | ) |
| 347 | |
| 348 | def _convert_tvm_dtype_to_torch(self, tvm_dtype: str) -> "torch.dtype": |
| 349 | """Convert TVM dtype string to PyTorch dtype.""" |