(idtype, dtype, A_nnz, B_nnz)
@pytest.mark.parametrize("A_nnz", [9000, 0])
@pytest.mark.parametrize("B_nnz", [9000, 0])
def test_csrmask(idtype, dtype, A_nnz, B_nnz):
    """Check ``_csrmask`` against a dense gather of A's weights at B's edges.

    Two random 500x600 graphs are built with ``A_nnz`` and ``B_nnz`` edges
    (the 0 parametrization exercises the empty-graph case). ``_csrmask`` is
    called with A's graph structure, A's ``"w"`` edge data and B's graph
    structure. The result must match the dense form of the reference matrix
    ``a`` (returned alongside A) indexed at every (row, col) edge of B, taken
    in B's edge-ID order.
    """
    # Only A's reference matrix is needed for the expected values; B's
    # reference matrix is not used, so it is discarded.
    a, A = _random_simple_graph(
        idtype, dtype, F.ctx(), 500, 600, A_nnz, "A", "B", "AB"
    )
    _, B = _random_simple_graph(
        idtype, dtype, F.ctx(), 500, 600, B_nnz, "A", "B", "AB"
    )
    C = dgl._sparse_ops._csrmask(A._graph, A.edata["w"], B._graph)

    # Endpoints of B's edges in edge-ID order, converted to numpy so they can
    # be used as fancy indices into the dense reference matrix.
    B_row, B_col = B.edges(order="eid")
    B_row = F.asnumpy(B_row)
    B_col = F.asnumpy(B_col)
    # NOTE(review): if `a` is a scipy sparse *matrix*, `todense()` returns an
    # np.matrix and this gather has shape (1, nnz) while C is 1-D; the
    # comparison then relies on F.allclose broadcasting — confirm.
    c = F.tensor(a.todense()[B_row, B_col], dtype)
    assert F.allclose(C, c)
| 216 | |
| 217 | |
| 218 | @parametrize_idtype |
no test coverage detected