(graph, x, actnn=False, fast_spmm=None, fast_spmm_cpu=None)
| 83 | |
| 84 | |
def spmm(graph, x, actnn=False, fast_spmm=None, fast_spmm_cpu=None):
    """Multiply the graph's adjacency structure with the dense matrix ``x``.

    Dispatch order:

    1. If ``graph`` carries a non-``None`` ``grb_adj`` attribute, the product
       is computed directly with ``torch.sparse.mm`` (sparse adjacency) or
       ``torch.mm`` (dense adjacency) and returned immediately.
    2. Otherwise, if an accelerated kernel is available and ``x`` does not
       live on the CPU, ``fast_spmm`` is used on the CSR data of the graph.
    3. Otherwise, if a CPU kernel is available, ``x`` lives on the CPU and
       ``x.requires_grad`` is ``False``, ``fast_spmm_cpu`` is used.
    4. Otherwise the computation falls back to ``spmm_scatter`` on
       ``graph.edge_index`` / ``graph.edge_weight``.

    For the kernel paths (2 and 3) ``x`` is scaled by ``graph.out_norm``
    before the product and by ``graph.in_norm`` after it, whenever those
    attributes are not ``None``.

    Args:
        graph: Graph object exposing ``out_norm``, ``in_norm``,
            ``row_indptr``, ``col_indices``, ``raw_edge_weight``,
            ``is_symmetric()``, ``edge_index`` and ``edge_weight``
            (and optionally ``grb_adj``).
        x: Dense feature tensor to multiply.
        actnn: Forwarded to ``fast_spmm`` as a keyword argument.
        fast_spmm: Kernel for non-CPU tensors. When ``None`` it is looked up
            in ``CONFIGS["fast_spmm"]`` after calling ``initialize_spmm()``.
        fast_spmm_cpu: Kernel for CPU tensors. When ``None`` it is looked up
            in ``CONFIGS["fast_spmm_cpu"]`` after calling
            ``initialize_spmm_cpu()``.

    Returns:
        The result of the sparse-dense product.
    """
    adjacency = getattr(graph, "grb_adj", None)
    if adjacency is not None:
        matmul = torch.sparse.mm if adjacency.is_sparse else torch.mm
        return matmul(adjacency, x)

    # Resolve kernels that the caller did not supply explicitly.
    if fast_spmm is None:
        initialize_spmm()
        fast_spmm = CONFIGS["fast_spmm"]
    if fast_spmm_cpu is None:
        initialize_spmm_cpu()
        fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]

    use_device_kernel = fast_spmm is not None and str(x.device) != "cpu"
    use_cpu_kernel = (
        not use_device_kernel
        and fast_spmm_cpu is not None
        and str(x.device) == "cpu"
        and x.requires_grad is False
    )

    # No suitable kernel: fall back to the scatter-based implementation.
    if not (use_device_kernel or use_cpu_kernel):
        src, dst = graph.edge_index
        return spmm_scatter(src, dst, graph.edge_weight, x)

    if graph.out_norm is not None:
        x = graph.out_norm * x

    indptr = graph.row_indptr.int()
    indices = graph.col_indices.int()
    weights = graph.raw_edge_weight

    if use_device_kernel:
        # Keep the edge weights in the same precision as half-precision input.
        if x.dtype == torch.half:
            weights = weights.half()
        x = fast_spmm(indptr, indices, x, weights, graph.is_symmetric(), actnn=actnn)
    else:
        x = fast_spmm_cpu(indptr, indices, weights, x)

    if graph.in_norm is not None:
        x = graph.in_norm * x
    return x
| 124 | |
| 125 | |
| 126 | class SpMM(torch.nn.Module): |
no test coverage detected