Creates a sparse matrix based on the diagonal values. Parameters ---------- val : torch.Tensor Diagonal of the matrix, in shape ``(N)`` or ``(N, D)`` shape : tuple[int, int], optional If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`, otherwi
(
val: torch.Tensor, shape: Optional[Tuple[int, int]] = None
)
| 1144 | |
| 1145 | |
| 1146 | def diag( |
| 1147 | val: torch.Tensor, shape: Optional[Tuple[int, int]] = None |
| 1148 | ) -> SparseMatrix: |
| 1149 | """Creates a sparse matrix based on the diagonal values. |
| 1150 | |
| 1151 | Parameters |
| 1152 | ---------- |
| 1153 | val : torch.Tensor |
| 1154 | Diagonal of the matrix, in shape ``(N)`` or ``(N, D)`` |
| 1155 | shape : tuple[int, int], optional |
| 1156 | If specified, :attr:`len(val)` must be equal to :attr:`min(shape)`, |
| 1157 | otherwise, it will be inferred from :attr:`val`, i.e., ``(N, N)`` |
| 1158 | |
| 1159 | Returns |
| 1160 | ------- |
| 1161 | SparseMatrix |
| 1162 | Sparse matrix |
| 1163 | |
| 1164 | Examples |
| 1165 | -------- |
| 1166 | |
| 1167 | Case1: 5-by-5 diagonal matrix with scalar values on the diagonal |
| 1168 | |
| 1169 | >>> import torch |
| 1170 | >>> val = torch.ones(5) |
| 1171 | >>> dglsp.diag(val) |
| 1172 | SparseMatrix(indices=tensor([[0, 1, 2, 3, 4], |
| 1173 | [0, 1, 2, 3, 4]]), |
| 1174 | values=tensor([1., 1., 1., 1., 1.]), |
| 1175 | shape=(5, 5), nnz=5) |
| 1176 | |
| 1177 | Case2: 5-by-10 diagonal matrix with scalar values on the diagonal |
| 1178 | |
| 1179 | >>> val = torch.ones(5) |
| 1180 | >>> dglsp.diag(val, shape=(5, 10)) |
| 1181 | SparseMatrix(indices=tensor([[0, 1, 2, 3, 4], |
| 1182 | [0, 1, 2, 3, 4]]), |
| 1183 | values=tensor([1., 1., 1., 1., 1.]), |
| 1184 | shape=(5, 10), nnz=5) |
| 1185 | |
| 1186 | Case3: 5-by-5 diagonal matrix with vector values on the diagonal |
| 1187 | |
| 1188 | >>> val = torch.randn(5, 3) |
| 1189 | >>> D = dglsp.diag(val) |
| 1190 | >>> D.shape |
| 1191 | (5, 5) |
| 1192 | >>> D.nnz |
| 1193 | 5 |
| 1194 | """ |
| 1195 | assert ( |
| 1196 | val.dim() <= 2 |
| 1197 | ), "The values of a DiagMatrix can only be scalars or vectors." |
| 1198 | len_val = len(val) |
| 1199 | if shape is not None: |
| 1200 | assert len_val == min(shape), ( |
| 1201 | f"Expect len(val) to be min(shape) for a diagonal matrix, got" |
| 1202 | f"{len_val} for len(val) and {shape} for shape." |
| 1203 | ) |