(self, out_sinfo, in_args=None)
| 282 | return py_func(self, *args) |
| 283 | |
| 284 | def _create_output_tensors(self, out_sinfo, in_args=None): |
| 285 | # pylint: disable=import-outside-toplevel |
| 286 | import torch |
| 287 | |
| 288 | sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo] |
| 289 | out_tensors = [] |
| 290 | for sinfo in sinfo_list: |
| 291 | if isinstance(sinfo, tuple | list) and all( |
| 292 | isinstance(x, int | np.integer) for x in sinfo |
| 293 | ): |
| 294 | out_tensors.append(torch.zeros(list(map(int, sinfo)), dtype=torch.float32)) |
| 295 | continue |
| 296 | |
| 297 | if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"): |
| 298 | concrete_shape = self._infer_concrete_shape_from_args(sinfo.shape, in_args) |
| 299 | torch_dtype = self._convert_tvm_dtype_to_torch(sinfo.dtype) |
| 300 | out_tensors.append(torch.zeros(concrete_shape, dtype=torch_dtype)) |
| 301 | continue |
| 302 | |
| 303 | out_tensors.append(torch.zeros((1,), dtype=torch.float32)) |
| 304 | return out_tensors |
| 305 | |
| 306 | def _infer_concrete_shape_from_args(self, shape, in_args): |
| 307 | concrete = [] |
no test coverage detected