(*args, **kwargs)
| 676 | |
def wrap_fxn(k, f):
  """Return a wrapper around `f` that converts arguments on the way in and results on the way out.

  The returned callable passes its positional and keyword arguments through
  `unwrap_args` and calls `f` with the converted values. It then passes the
  result through `wrap`: a single `Tensor` is wrapped directly, and a tuple is
  wrapped element by element. Any other result type raises `RuntimeError`.

  k: label printed in debug output and in the error message (presumably the
     op name this wrapper is registered under -- confirm against callers).
  f: callable invoked with the unwrapped arguments.
  """
  def nf(*args, **kwargs):
    if TORCH_DEBUG:
      # Logged before unwrapping, so shapes are shown for torch.Tensor inputs.
      # Comprehension variables are named distinctly so they do not shadow the
      # outer `k` (the label printed first).
      print(k, len(args), [a.shape if isinstance(a, torch.Tensor) else a for a in args],
            {name: v.shape if isinstance(v, torch.Tensor) else v for name, v in kwargs.items()})
    args, kwargs = unwrap_args(args, kwargs)
    out = f(*args, **kwargs)
    # NOTE(review): `Tensor` is presumably the backend tensor type (distinct from
    # torch.Tensor used above); tuple elements are all passed to `wrap` unchecked.
    if isinstance(out, Tensor): return wrap(out)
    if isinstance(out, tuple): return tuple(wrap(x) for x in out)
    # Include the label so the offending wrapped callable is identifiable from the message.
    raise RuntimeError(f"unknown output type {type(out)} from {k}")
  return nf
| 688 | |
| 689 | def wrap_inplace(k,f): |
nothing calls this directly
no test coverage detected