Efficient version of torch.cat that avoids a copy if there is only a single element in a list
(tensors: List[torch.Tensor], dim: int = 0)
| 48 | |
| 49 | |
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate a sequence of tensors along ``dim``, skipping ``torch.cat``
    when the sequence holds exactly one tensor.

    Args:
        tensors: list or tuple of tensors to join.
        dim: dimension to concatenate along; ignored when exactly one tensor
            is given.

    Returns:
        For a single-element sequence, ``tensors[0]`` itself. No copy is made,
        so the result aliases the input tensor. Otherwise, the result of
        ``torch.cat(tensors, dim)``.

    Raises:
        AssertionError: if ``tensors`` is not a list or tuple. This check uses
            ``assert``, so it is skipped when Python runs with ``-O``.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Lone tensor: hand it back untouched rather than paying for a copy.
    return tensors[0]
| 58 | |
| 59 | |
| 60 | def empty_input_loss_func_wrapper(loss_func): |
no test coverage detected