(x: torch.Tensor,
tp_size: int,
idx: int,
dim: int = 0)
| 392 | |
| 393 | |
def split(x: torch.Tensor,
          tp_size: int,
          idx: int,
          dim: int = 0) -> torch.Tensor:
    """Return the ``idx``-th of ``tp_size`` equal slices of ``x`` along ``dim``.

    The result is a view of ``x``, as with ``torch.split``; no data is copied.

    Args:
        x: Tensor to slice.
        tp_size: Number of equal slices to cut ``x.shape[dim]`` into; must be >= 1.
        idx: Which slice to return. Negative values count from the end, as with
            normal Python indexing. Ignored when ``tp_size == 1``.
        dim: Dimension to slice along (negative values allowed).

    Returns:
        ``x`` itself when ``tp_size == 1``. Otherwise, a view of ``x`` whose
        size along ``dim`` is ``x.shape[dim] // tp_size``.

    Raises:
        ValueError: if ``tp_size < 1`` or ``x.shape[dim]`` is not divisible by
            ``tp_size``.
        IndexError: if ``dim`` is out of range for ``x``, or ``idx`` is outside
            ``[-tp_size, tp_size)``.
    """
    if tp_size < 1:
        raise ValueError(f"tp_size must be >= 1, got {tp_size}")
    # Read the size before the tp_size == 1 shortcut so an invalid ``dim``
    # still raises IndexError, as it always has.
    size = x.shape[dim]
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, and an uneven split would then silently yield a
    # wrongly-shaped shard.
    if size % tp_size != 0:
        raise ValueError(
            f"x.shape[{dim}] = {size} is not divisible by tp_size = {tp_size}")
    if tp_size == 1:
        return x
    if not -tp_size <= idx < tp_size:
        raise IndexError(
            f"shard index {idx} out of range for tp_size = {tp_size}")
    split_size = size // tp_size
    # Normalize negative idx, then build only the one view we need instead of
    # materializing all ``tp_size`` views via torch.split.
    start = (idx % tp_size) * split_size
    return x.narrow(dim, start, split_size)
| 403 | |
| 404 | |
| 405 | def relu2(x: torch.Tensor) -> torch.Tensor: |
no test coverage detected