Creates a sparse matrix from a torch sparse tensor, which can have coo, csr, or csc layout. Parameters ---------- torch_sparse_tensor : torch.Tensor Torch sparse tensor Returns ------- SparseMatrix Sparse matrix Examples -------- >>> indice
(torch_sparse_tensor: torch.Tensor)
| 1282 | |
| 1283 | |
def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
    """Creates a sparse matrix from a torch sparse tensor, which can have coo,
    csr, or csc layout.

    Only the first two entries of ``torch_sparse_tensor.shape`` are used as
    the shape of the resulting sparse matrix. For COO input, the raw indices
    and values (``_indices()`` / ``_values()``) are forwarded as-is, so an
    uncoalesced tensor is not coalesced by this function.

    Parameters
    ----------
    torch_sparse_tensor : torch.Tensor
        Torch sparse tensor whose layout is ``torch.sparse_coo``,
        ``torch.sparse_csr``, or ``torch.sparse_csc``.

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Raises
    ------
    AssertionError
        If the layout of ``torch_sparse_tensor`` is not one of the three
        supported layouts. The check is performed even when Python runs
        with ``-O``.

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> torch_coo = torch.sparse_coo_tensor(indices, val)
    >>> dglsp.from_torch_sparse(torch_coo)
    SparseMatrix(indices=tensor([[1, 1, 2],
                                 [2, 4, 3]]),
                 values=tensor([1., 1., 1.]),
                 shape=(3, 5), nnz=3)
    """
    layout = torch_sparse_tensor.layout
    # Only the leading two dimensions describe the matrix shape.
    shape = torch_sparse_tensor.shape[:2]

    if layout == torch.sparse_coo:
        # Use ._indices() and ._values() to access uncoalesced indices and
        # values.
        return spmatrix(
            torch_sparse_tensor._indices(),
            torch_sparse_tensor._values(),
            shape,
        )
    if layout == torch.sparse_csr:
        return from_csr(
            torch_sparse_tensor.crow_indices(),
            torch_sparse_tensor.col_indices(),
            torch_sparse_tensor.values(),
            shape,
        )
    if layout == torch.sparse_csc:
        return from_csc(
            torch_sparse_tensor.ccol_indices(),
            torch_sparse_tensor.row_indices(),
            torch_sparse_tensor.values(),
            shape,
        )

    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped under ``python -O`` (which would let an unsupported layout
    # fall into a CSC-specific call). AssertionError is kept so callers see
    # the same exception type as before.
    raise AssertionError(
        f"Cannot convert Pytorch sparse tensor with layout "
        f"{torch_sparse_tensor.layout} to DGL sparse."
    )
| 1340 | |
| 1341 |