Method index_select

github.com/dmlc/dgl · python/dgl/sparse/sparse_matrix.py:482–531

Returns a sub-matrix selected according to the given index.

(self, dim: int, index: torch.Tensor)


        return self.c_sparse_matrix.is_diag()

    def index_select(self, dim: int, index: torch.Tensor):
        """Returns a sub-matrix selected according to the given index.

        Parameters
        ----------
        dim : int
            The dimension to select from the matrix, which should be 0 or 1.
            `dim = 0` selects rows and `dim = 1` selects columns.
        index : torch.Tensor
            The IDs along `dim` to select from the matrix.
            Duplicate IDs are allowed.

        The function does not support autograd.

        Returns
        -------
        SparseMatrix
            The sub-matrix which contains the selected rows or columns.

        Examples
        --------

        >>> indices = torch.tensor([[0, 1, 1, 2, 3, 4], [0, 2, 4, 3, 5, 0]])
        >>> val = torch.tensor([0, 1, 2, 3, 4, 5])
        >>> A = dglsp.spmatrix(indices, val)

        Case 1: Select rows by IDs.

        >>> row_ids = torch.tensor([0, 1, 4])
        >>> A.index_select(0, row_ids)
        SparseMatrix(indices=tensor([[0, 1, 1, 2],
                                     [0, 2, 4, 0]]),
                     values=tensor([0, 1, 2, 5]),
                     shape=(3, 6), nnz=4)

        Case 2: Select columns by IDs.

        >>> column_ids = torch.tensor([0, 4, 5])
        >>> A.index_select(1, column_ids)
        SparseMatrix(indices=tensor([[0, 4, 1, 3],
                                     [0, 0, 1, 2]]),
                     values=tensor([0, 5, 2, 4]),
                     shape=(5, 3), nnz=4)
        """
        if dim not in (0, 1):
            raise ValueError("The selection dimension should be 0 or 1.")
        if isinstance(index, torch.Tensor):
            return SparseMatrix(self.c_sparse_matrix.index_select(dim, index))
        raise TypeError(f"{type(index).__name__} is unsupported input type.")

    def range_select(self, dim: int, index: slice):
        """Returns a sub-matrix selected according to the given range index.
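To make the selection semantics concrete, here is a minimal pure-Python sketch of row/column selection on a COO matrix. It is not DGL's implementation (the method above delegates to `c_sparse_matrix.index_select` in C++); the helper name `coo_index_select` and the `(rows, cols, vals, shape)` layout are hypothetical. The sketch groups output nonzeros by the position of each selected ID, which reproduces both docstring cases; DGL's output order is not guaranteed to match this sketch in general.

```python
from typing import List, Sequence, Tuple

# Hypothetical COO layout: (row ids, column ids, values, (num_rows, num_cols)).
COO = Tuple[List[int], List[int], List[int], Tuple[int, int]]


def coo_index_select(rows: Sequence[int], cols: Sequence[int],
                     vals: Sequence[int], shape: Tuple[int, int],
                     dim: int, index: Sequence[int]) -> COO:
    """Select rows (dim=0) or columns (dim=1) of a COO matrix by ID.

    Duplicate IDs are allowed: each occurrence becomes its own output
    row/column. Output nonzeros are grouped by position in `index`,
    keeping the original nonzero order inside each group.
    """
    if dim not in (0, 1):
        raise ValueError("The selection dimension should be 0 or 1.")
    keys = rows if dim == 0 else cols  # coordinate that IDs are matched on
    out_r: List[int] = []
    out_c: List[int] = []
    out_v: List[int] = []
    for new_id, old_id in enumerate(index):
        for k in range(len(vals)):
            if keys[k] != old_id:
                continue
            if dim == 0:
                out_r.append(new_id)   # selected row is renumbered
                out_c.append(cols[k])  # column is unchanged
            else:
                out_r.append(rows[k])  # row is unchanged
                out_c.append(new_id)   # selected column is renumbered
            out_v.append(vals[k])
    new_shape = (len(index), shape[1]) if dim == 0 else (shape[0], len(index))
    return out_r, out_c, out_v, new_shape


# Same 5 x 6 matrix as in the docstring example.
rows = [0, 1, 1, 2, 3, 4]
cols = [0, 2, 4, 3, 5, 0]
vals = [0, 1, 2, 3, 4, 5]
shape = (5, 6)

print(coo_index_select(rows, cols, vals, shape, 0, [0, 1, 4]))
# ([0, 1, 1, 2], [0, 2, 4, 0], [0, 1, 2, 5], (3, 6))
print(coo_index_select(rows, cols, vals, shape, 1, [0, 4, 5]))
# ([0, 4, 1, 3], [0, 0, 1, 2], [0, 5, 2, 4], (5, 3))
print(coo_index_select(rows, cols, vals, shape, 0, [1, 1]))  # duplicate ID
# ([0, 0, 1, 1], [2, 4, 2, 4], [1, 2, 1, 2], (2, 6))
```

The nested loop costs O(len(index) × nnz), which keeps the sketch short; a production implementation would more likely use a compressed (CSR/CSC) view so each selected ID is looked up directly.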

Callers 15

gather_mm · Function · 0.80
gather_row · Function · 0.80
take · Function · 0.80
_fetch_cpu · Function · 0.80
_fetch_cuda · Function · 0.80
forward · Method · 0.80
matmul_maybe_select · Function · 0.80
bmm_maybe_select · Function · 0.80
_to_reverse_ids · Function · 0.80
index_select · Function · 0.80
from_dglgraph · Function · 0.80

Calls 1

SparseMatrix · Class · 0.70

Tested by 4

test_compact · Function · 0.64
test_index_select · Function · 0.64
test_index_select · Function · 0.64
test_feature_cache · Function · 0.64