Efficient version of torch.cat that avoids a copy when the list contains only a single element.
(tensors: List[torch.Tensor], dim: int = 0)
| 36 | |
| 37 | |
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Join ``tensors`` along dimension ``dim``, like ``torch.cat``.

    When the sequence holds exactly one tensor, that very tensor object is
    handed back unchanged instead of being copied by ``torch.cat``; callers
    therefore receive an alias of the input in that case.

    Args:
        tensors: a ``list`` or ``tuple`` of tensors to concatenate.
        dim: dimension along which to concatenate (forwarded to ``torch.cat``).

    Returns:
        ``tensors[0]`` itself for a single-element input, otherwise the
        result of ``torch.cat(tensors, dim)``.
    """
    # Only list/tuple containers are accepted. This is an assert, so the
    # check is skipped when Python runs with -O.
    assert isinstance(tensors, (list, tuple))
    single = len(tensors) == 1
    if not single:
        return torch.cat(tensors, dim)
    # Fast path: nothing to join, so skip the allocation and copy.
    return tensors[0]
| 46 | |
| 47 | |
| 48 | def cross_entropy(input, target, *, reduction="mean", **kwargs): |
no test coverage detected