hub / github.com/dmlc/dgl / from_torch_sparse

Function from_torch_sparse

python/dgl/sparse/sparse_matrix.py:1284–1339

Creates a sparse matrix from a torch sparse tensor, which can have coo, csr, or csc layout.

(torch_sparse_tensor: torch.Tensor)
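The three accepted layouts describe the same nonzero positions in different ways. The following is a minimal pure-Python sketch of that relationship, not code from DGL or PyTorch: the helper `compressed_to_coo` is a hypothetical name introduced for illustration, and the index arrays are written by hand for the 3 x 5 matrix with nonzeros at (1, 2), (1, 4), and (2, 3).

```python
from typing import List, Tuple


def compressed_to_coo(ptr: List[int], idx: List[int]) -> List[Tuple[int, int]]:
    """Expand a compressed pointer array into (major, minor) index pairs.

    Hypothetical helper for illustration only.
    For CSR, ptr plays the role of crow_indices and idx of col_indices,
    so each pair is (row, col). For CSC, ptr plays the role of
    ccol_indices and idx of row_indices, so each pair is (col, row).
    """
    pairs = []
    for major in range(len(ptr) - 1):
        # Entries ptr[major] .. ptr[major + 1] - 1 belong to this row/column.
        for k in range(ptr[major], ptr[major + 1]):
            pairs.append((major, idx[k]))
    return pairs


# CSR: 3 rows, so 4 row pointers. Row 0 is empty, row 1 has 2 entries, row 2 has 1.
crow_indices = [0, 0, 2, 3]
col_indices = [2, 4, 3]
csr_pairs = compressed_to_coo(crow_indices, col_indices)

# CSC: 5 columns, so 6 column pointers. Columns 0 and 1 are empty.
ccol_indices = [0, 0, 0, 1, 2, 3]
row_indices = [1, 2, 1]
# Swap (col, row) back to (row, col) and sort to compare with the CSR result.
csc_pairs = sorted((r, c) for c, r in compressed_to_coo(ccol_indices, row_indices))
```

Both `csr_pairs` and `csc_pairs` describe the same coordinates that a COO tensor would store directly as an index matrix, which is why a single entry point can accept all three layouts.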

Source

def from_torch_sparse(torch_sparse_tensor: torch.Tensor) -> SparseMatrix:
    """Creates a sparse matrix from a torch sparse tensor, which can have coo,
    csr, or csc layout.

    Parameters
    ----------
    torch_sparse_tensor : torch.Tensor
        Torch sparse tensor

    Returns
    -------
    SparseMatrix
        Sparse matrix

    Examples
    --------

    >>> indices = torch.tensor([[1, 1, 2], [2, 4, 3]])
    >>> val = torch.ones(3)
    >>> torch_coo = torch.sparse_coo_tensor(indices, val)
    >>> dglsp.from_torch_sparse(torch_coo)
    SparseMatrix(indices=tensor([[1, 1, 2],
                                 [2, 4, 3]]),
                 values=tensor([1., 1., 1.]),
                 shape=(3, 5), nnz=3)
    """
    assert torch_sparse_tensor.layout in (
        torch.sparse_coo,
        torch.sparse_csr,
        torch.sparse_csc,
    ), (
        f"Cannot convert Pytorch sparse tensor with layout "
        f"{torch_sparse_tensor.layout} to DGL sparse."
    )
    if torch_sparse_tensor.layout == torch.sparse_coo:
        # Use ._indices() and ._values() to access uncoalesced indices and
        # values.
        return spmatrix(
            torch_sparse_tensor._indices(),
            torch_sparse_tensor._values(),
            torch_sparse_tensor.shape[:2],
        )
    elif torch_sparse_tensor.layout == torch.sparse_csr:
        return from_csr(
            torch_sparse_tensor.crow_indices(),
            torch_sparse_tensor.col_indices(),
            torch_sparse_tensor.values(),
            torch_sparse_tensor.shape[:2],
        )
    else:
        return from_csc(
            torch_sparse_tensor.ccol_indices(),
            torch_sparse_tensor.row_indices(),
            torch_sparse_tensor.values(),
            torch_sparse_tensor.shape[:2],
        )
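The COO branch reads `._indices()` and `._values()`, which the source comment says expose uncoalesced data, so any duplicate coordinates reach `spmatrix` exactly as stored. To show what "uncoalesced" means, here is a rough pure-Python sketch of coalescing, assuming duplicates are merged by summing their values (as PyTorch's `coalesce()` does). The `coalesce` function below is a hypothetical stand-in written for illustration, not PyTorch or DGL code.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def coalesce(
    indices: List[Tuple[int, int]], values: List[float]
) -> Tuple[List[Tuple[int, int]], List[float]]:
    """Merge duplicate (row, col) entries by summing their values.

    Hypothetical helper for illustration; returns entries sorted by coordinate.
    """
    acc: Dict[Tuple[int, int], float] = defaultdict(float)
    for coord, v in zip(indices, values):
        acc[coord] += v
    keys = sorted(acc)
    return keys, [acc[k] for k in keys]


# An uncoalesced COO description: coordinate (1, 2) is stored twice.
raw_indices = [(1, 2), (1, 2), (2, 3)]
raw_values = [1.0, 1.0, 1.0]

merged_indices, merged_values = coalesce(raw_indices, raw_values)
# The raw form keeps 3 stored entries; the merged form keeps 2.
```

Reading the raw accessors means the conversion itself does not collapse the three stored entries into two. Whether duplicates are merged later is outside what this function shows.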

Calls 4

spmatrix · Function · 0.85
from_csc · Function · 0.85
from_csr · Function · 0.70
values · Method · 0.45