```python
def sample(
    self,
    dim: int,
    fanout: int,
    ids: Optional[torch.Tensor] = None,
    replace: Optional[bool] = False,
    bias: Optional[bool] = False,
):
    """Returns a sparse matrix of non-zero elements sampled along `dim`.

    Parameters
    ----------
    dim : int
        The dimension for sampling, which must be 0 or 1. `dim = 0`
        samples non-zero elements row by row and `dim = 1` samples them
        column by column.
    fanout : int
        The number of non-zero elements to randomly sample from each row
        or column.
    ids : torch.Tensor, optional
        The row IDs (if `dim = 0`) or column IDs (if `dim = 1`) to sample
        from.
        NOTE: If `ids` is not provided (i.e., `ids = None`), the function
        samples from all rows or columns.
    replace : bool, optional
        Whether the same non-zero element may be sampled more than once.
        When `replace = True`, repeated sampling is permitted; when
        `replace = False`, it is not.
        NOTE: If `replace = False` and a row or column has fewer non-zero
        elements than `fanout`, all of its non-zero elements are sampled.
    bias : bool, optional
        Whether to bias the sampling. When `bias = True`, the values of
        the sparse matrix are used as sampling weights, so elements with
        larger values are more likely to be sampled.

    The function does not support autograd.

    Returns
    -------
    SparseMatrix
        A sparse matrix containing the randomly sampled non-zero elements.
        Along `dim`, its size equals the number of `ids` (the original size
        when `ids = None`) and index `k` corresponds to `ids[k]`; the other
        dimension keeps the size of the original matrix.

    Examples
    --------

    >>> import torch
    >>> import dgl.sparse as dglsp
    >>> indices = torch.tensor([[0, 0, 1, 1, 2, 2, 2],
    ...                         [0, 2, 0, 1, 0, 1, 2]])
    >>> val = torch.tensor([0, 1, 2, 3, 4, 5, 6])
    >>> A = dglsp.spmatrix(indices, val)

    Case 1: Sample at most two non-zero elements from each of rows 0 and 2,
    without replacement. Row 0 has exactly two non-zero elements, so both
    are kept; the two elements taken from row 2 are random, so the output
    below is one possible result. Row 2 of `A` becomes row 1 of the result.

    >>> row_ids = torch.tensor([0, 2])
    >>> A.sample(0, 2, row_ids)
    SparseMatrix(indices=tensor([[0, 0, 1, 1],
                                 [0, 2, 0, 2]]),
                 values=tensor([0, 1, 4, 6]),
                 shape=(2, 3), nnz=4)
```
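The `replace = False` note implies a simple count for the size of the result. Writing $\deg(i)$ for the number of non-zero elements in row (or column) $i$ (notation introduced here for illustration), each requested row or column contributes $\min(\text{fanout}, \deg(i))$ elements. In Case 1, $\deg(0) = 2$ and $\deg(2) = 3$ with `fanout = 2`, which reproduces the `nnz=4` shown in the example output:

$$
\mathrm{nnz}_{\text{out}} = \sum_{i \in \text{ids}} \min\bigl(\text{fanout},\ \deg(i)\bigr),
\qquad
\min(2, 2) + \min(2, 3) = 2 + 2 = 4 .
$$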
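To make the parameter semantics concrete, here is a minimal pure-Python sketch of the uniform (`bias = False`) behavior described in the docstring, operating on plain COO lists instead of a `SparseMatrix`. It assumes what the docstring and Case 1 show: the sampled dimension is reindexed to positions within `ids`, and the other dimension is kept. The helper name `sample_coo`, its list-based inputs, and its `rng` parameter are hypothetical; this is not DGL's implementation, and details such as how empty rows are handled (here they simply contribute nothing) are assumptions.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple


def sample_coo(
    rows: Sequence[int],
    cols: Sequence[int],
    vals: Sequence[float],
    shape: Tuple[int, int],
    dim: int,
    fanout: int,
    ids: Optional[Sequence[int]] = None,
    replace: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int], List[float], Tuple[int, int]]:
    """Hypothetical uniform (bias=False) model of `sample` on COO lists."""
    if dim not in (0, 1):
        raise ValueError("dim should be 0 or 1")
    rng = rng or random.Random()
    # "major" is the dimension being sampled, "minor" the one kept as-is.
    major, minor = (rows, cols) if dim == 0 else (cols, rows)
    if ids is None:
        ids = range(shape[dim])  # sample from every row (or column)

    # Group the positions of the non-zero elements by row (or column) ID.
    groups: Dict[int, List[int]] = {}
    for pos, m in enumerate(major):
        groups.setdefault(m, []).append(pos)

    out_major: List[int] = []
    out_minor: List[int] = []
    out_vals: List[float] = []
    for new_id, old_id in enumerate(ids):
        candidates = groups.get(old_id, [])
        if not candidates:
            continue  # nothing to sample in an empty row (or column)
        if replace:
            # Exactly `fanout` uniform draws; a position may repeat.
            picked = sorted(rng.choices(candidates, k=fanout))
        elif len(candidates) <= fanout:
            # Fewer non-zeros than `fanout`: keep all of them.
            picked = candidates
        else:
            # `fanout` distinct positions drawn uniformly at random.
            picked = sorted(rng.sample(candidates, fanout))
        for pos in picked:
            out_major.append(new_id)  # reindex to the position within `ids`
            out_minor.append(minor[pos])
            out_vals.append(vals[pos])

    if dim == 0:
        return out_major, out_minor, out_vals, (len(ids), shape[1])
    return out_minor, out_major, out_vals, (shape[0], len(ids))


# Case 1 from the docstring: rows 0 and 2, fanout 2, no replacement.
rows = [0, 0, 1, 1, 2, 2, 2]
cols = [0, 2, 0, 1, 0, 1, 2]
vals = [0, 1, 2, 3, 4, 5, 6]
r, c, v, shp = sample_coo(rows, cols, vals, (3, 3), dim=0, fanout=2,
                          ids=[0, 2], rng=random.Random(0))
print(shp)  # (2, 3); the entries drawn from row 2 depend on the seed
```

Grouping by the sampled dimension first keeps each row's draw independent, which mirrors the per-row (or per-column) `fanout` described above.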
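For `bias = True`, the docstring only states that the matrix values act as sampling weights. The sketch below shows one way to pick elements from a single row or column under that reading: weighted draws via `random.choices` when `replace = True`, and the Efraimidis-Spirakis key method (keep the `fanout` largest keys $u^{1/w}$) when `replace = False`. The helper `biased_pick` is hypothetical, and the choice of Efraimidis-Spirakis is an assumption of this sketch; DGL may use a different weighted-sampling scheme and may treat zero or negative weights differently.

```python
import heapq
import random
from typing import List, Optional, Sequence


def biased_pick(
    weights: Sequence[float],
    fanout: int,
    replace: bool = False,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: pick indices into `weights`, biased by weight."""
    rng = rng or random.Random()
    n = len(weights)
    if n == 0:
        return []
    if replace:
        # With replacement: `fanout` independent draws, each with
        # probability proportional to its weight.
        return sorted(rng.choices(range(n), weights=weights, k=fanout))
    if n <= fanout:
        # Fewer elements than `fanout`: keep all of them.
        return list(range(n))
    # Without replacement (Efraimidis-Spirakis): give each element the key
    # u ** (1 / w) with u uniform in [0, 1) and keep the `fanout` largest.
    # Non-positive weights get key -1.0, so in this sketch they are picked
    # only after every positive-weight element.
    keys = []
    for i, w in enumerate(weights):
        if w > 0:
            keys.append((rng.random() ** (1.0 / w), i))
        else:
            keys.append((-1.0, i))
    return sorted(i for _, i in heapq.nlargest(fanout, keys))


# Row 2 of the example matrix has values [4, 5, 6]; with bias, the element
# of value 6 is the most likely to survive a fanout-2 draw.
print(biased_pick([4.0, 5.0, 6.0], 2, rng=random.Random(0)))
```

Applied per row (or column) in place of the uniform draw, this gives a weighted counterpart of the uniform model.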