Add an operation to transpose two dimensions of a tensor. The operation produces a tensor in which dimensions 'dim0' and 'dim1' are swapped. If the tensor's rank is greater than 2, the other dimensions remain untouched. The function is a helper built on the 'functional.permute' function.
transpose(input: Tensor, dim0: int, dim1: int) -> Tensor
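The core of this helper is an index swap on the identity permutation, which is then handed to 'permute'. The following minimal sketch reproduces only that step in plain Python, with no tensor library. `transpose_permutation` and `permuted_shape` are hypothetical helpers written for illustration, not part of the library. The sketch makes two further assumptions. First, negative dims are normalized here for clarity, which may differ from how 'functional.permute' resolves them. Second, `permuted_shape` assumes that output dimension `i` takes input dimension `permutation[i]`. A single swap is its own inverse, so that convention does not change the result for a transpose.

```python
from typing import List


def transpose_permutation(ndim: int, dim0: int, dim1: int) -> List[int]:
    """Hypothetical helper: permutation that swaps dim0 and dim1.

    Mirrors the swap in `transpose`: start from [0, 1, ..., ndim - 1]
    and exchange the two entries. Normalizing negative dims is an
    assumption made for this sketch.
    """
    if not (-ndim <= dim0 < ndim and -ndim <= dim1 < ndim):
        raise IndexError(f"dims ({dim0}, {dim1}) out of range for rank {ndim}")
    dim0 %= ndim  # e.g. -1 -> ndim - 1
    dim1 %= ndim
    permutation = list(range(ndim))
    permutation[dim0] = dim1
    permutation[dim1] = dim0
    return permutation


def permuted_shape(shape: List[int], permutation: List[int]) -> List[int]:
    """Hypothetical helper: output dim i takes the size of input dim permutation[i]."""
    return [shape[p] for p in permutation]


if __name__ == "__main__":
    shape = [2, 3, 4, 5]
    perm = transpose_permutation(len(shape), 1, 3)
    print(perm)                         # [0, 3, 2, 1]
    print(permuted_shape(shape, perm))  # [2, 5, 4, 3]
```

Only the two swapped entries differ from the identity, which is why the dimensions other than 'dim0' and 'dim1' keep their positions.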
```python
def transpose(input: Tensor, dim0: int, dim1: int) -> Tensor:
    '''
    Add an operation to transpose two dimensions of a tensor.

    That operation produces a tensor in which the dimensions 'dim0' and 'dim1'
    are permuted. The other dimensions, if the rank of the tensor is greater
    than 2, remain untouched.

    That function is a helper built on the 'functional.permute' function.

    Parameters:
        input : Tensor
            The input tensor to transpose.

        dim0 : int
            The first dimension to transpose.

        dim1 : int
            The second dimension to transpose.

    Returns:
        The tensor produced by the permutation layer.
    '''
    # Start from the identity permutation [0, 1, ..., rank - 1] and exchange
    # the entries at positions dim0 and dim1; every other dimension keeps
    # its original position.
    permutation = list(range(input.ndim()))
    permutation[dim0] = dim1
    permutation[dim1] = dim0

    return permute(input, permutation)
```
| 1745 | def view(input: Tensor, |