(torch_tensor, dtype=None)
| 202 | |
| 203 | |
def wrap_torch_tensor(torch_tensor, dtype=None):
    """Wrap ``torch_tensor`` in a ``Tensor`` built on ``Storage(torch_tensor)``.

    The resulting shape starts as ``torch_tensor.shape``. The first dimension
    whose stride is 1 is then multiplied by
    ``bitwidth(torch_tensor.dtype) // bitwidth(dtype)``, so a storage element
    that holds several narrower logical elements is reflected in the shape.

    Args:
        torch_tensor: Object exposing ``dtype``, ``shape`` and ``stride()``
            (presumably a ``torch.Tensor`` -- confirm against callers).
        dtype: Logical element type of the wrapper. Defaults to
            ``torch_tensor.dtype``, in which case the shape is unchanged.

    Returns:
        ``Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)``.

    Raises:
        ValueError: If no dimension of ``torch_tensor`` has stride 1, or if
            ``dtype`` is wider than ``torch_tensor.dtype`` so that the
            element ratio would floor to zero.
    """
    if dtype is None:
        dtype = torch_tensor.dtype

    strides = torch_tensor.stride()
    try:
        # First unit-stride dimension; this is the one whose extent is scaled.
        unit_stride_dim = strides.index(1)
    except ValueError:
        # Same exception type as the bare ``index`` call, but with a message
        # that names the actual problem instead of "x not in tuple".
        raise ValueError(
            "wrap_torch_tensor requires a dimension with stride 1, "
            f"got strides {tuple(strides)}"
        ) from None

    elems_per_storage_elem = bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    if elems_per_storage_elem < 1:
        # A ratio of 0 would silently turn the scaled dimension into size 0.
        raise ValueError(
            f"dtype {dtype} is wider than the storage dtype "
            f"{torch_tensor.dtype}; cannot derive a wrapped shape"
        )

    shape = list(torch_tensor.shape)
    shape[unit_stride_dim] *= elems_per_storage_elem
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)
| 210 | |
| 211 | |
| 212 | def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs): |
no test coverage detected