`sparse_mx_to_torch_sparse_tensor(sparse_mx)`: Convert a scipy sparse matrix to a torch sparse tensor.
| 71 | |
| 72 | |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse COO tensor.

    Args:
        sparse_mx: Any scipy sparse matrix (anything exposing ``.tocoo()``).

    Returns:
        A ``torch.sparse_coo`` tensor with the same shape as ``sparse_mx``.
        Its values are cast to float32 and its indices to int64. The tensor
        is not coalesced, so duplicate entries in a COO input are kept as-is.
        Call ``.coalesce()`` on the result if a canonical form is needed.
    """
    coo = sparse_mx.tocoo().astype(np.float32)
    # torch expects a (2, nnz) int64 index matrix: row indices, then column indices.
    indices = torch.from_numpy(
        np.vstack((coo.row, coo.col)).astype(np.int64))
    values = torch.from_numpy(coo.data)
    shape = torch.Size(coo.shape)
    # torch.sparse.FloatTensor(...) is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported equivalent.
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)