MCPcopy
hub / github.com/THUDM/CogDL / spmm

Function spmm

cogdl/utils/spmm_utils.py:85–123  ·  view source on GitHub ↗
(graph, x, actnn=False, fast_spmm=None, fast_spmm_cpu=None)

Source from the content-addressed store, hash-verified

83
84
def spmm(graph, x, actnn=False, fast_spmm=None, fast_spmm_cpu=None):
    """Multiply the graph's adjacency structure by the dense feature matrix ``x``.

    Backends are tried in this order:

    1. If ``graph`` has a non-None ``grb_adj`` attribute, that matrix is used
       directly: ``torch.sparse.mm`` when it is sparse, ``torch.mm`` otherwise.
       No ``out_norm``/``in_norm`` scaling is applied on this path.
    2. If a ``fast_spmm`` kernel is available and ``x`` is not on the CPU, the
       CSR form of the graph is passed to it.
    3. Else, if a ``fast_spmm_cpu`` kernel is available, ``x`` is on the CPU
       and ``x`` does not require grad, the CSR form is passed to that kernel.
    4. Otherwise ``spmm_scatter`` is used on ``graph.edge_index`` /
       ``graph.edge_weight``.

    On paths 2 and 3, ``x`` is multiplied by ``graph.out_norm`` before the
    kernel call and by ``graph.in_norm`` after it, whenever those are set.

    Args:
        graph: Graph object providing the attributes read above
            (``row_indptr``, ``col_indices``, ``raw_edge_weight``,
            ``edge_index``, ``edge_weight``, ``out_norm``, ``in_norm``,
            ``is_symmetric()`` and optionally ``grb_adj``).
        x: Dense feature tensor to propagate.
        actnn: Forwarded unchanged to the ``fast_spmm`` kernel only.
        fast_spmm: Non-CPU kernel. When None, it is resolved via
            ``initialize_spmm()`` and ``CONFIGS["fast_spmm"]``. It may still
            be None afterwards, in which case this backend is skipped.
        fast_spmm_cpu: CPU kernel. When None, it is resolved via
            ``initialize_spmm_cpu()`` and ``CONFIGS["fast_spmm_cpu"]``. It may
            still be None afterwards, in which case this backend is skipped.

    Returns:
        The propagated feature tensor.
    """
    # Single lookup; equivalent to `hasattr(...) and ... is not None`.
    grb_adj = getattr(graph, "grb_adj", None)
    if grb_adj is not None:
        if grb_adj.is_sparse:
            return torch.sparse.mm(grb_adj, x)
        return torch.mm(grb_adj, x)

    if fast_spmm is None:
        initialize_spmm()
        fast_spmm = CONFIGS["fast_spmm"]
    if fast_spmm_cpu is None:
        initialize_spmm_cpu()
        fast_spmm_cpu = CONFIGS["fast_spmm_cpu"]

    # Compare the device *type*: str(x.device) would read "cpu:0" for an
    # indexed CPU device and wrongly route the tensor to the non-CPU kernel.
    on_cpu = x.device.type == "cpu"

    if fast_spmm is not None and not on_cpu:
        if graph.out_norm is not None:
            x = graph.out_norm * x

        row_ptr, col_indices = graph.row_indptr, graph.col_indices
        csr_data = graph.raw_edge_weight
        # Match half-precision features; other dtypes are passed unchanged.
        if x.dtype == torch.half:
            csr_data = csr_data.half()
        # The kernel receives int32 CSR indices.
        x = fast_spmm(row_ptr.int(), col_indices.int(), x, csr_data, graph.is_symmetric(), actnn=actnn)

        if graph.in_norm is not None:
            x = graph.in_norm * x
    elif fast_spmm_cpu is not None and on_cpu and not x.requires_grad:
        # NOTE(review): the CPU kernel is only used when no gradient is needed;
        # presumably it has no backward pass - confirm against its definition.
        if graph.out_norm is not None:
            x = graph.out_norm * x

        row_ptr, col_indices = graph.row_indptr, graph.col_indices
        csr_data = graph.raw_edge_weight
        x = fast_spmm_cpu(row_ptr.int(), col_indices.int(), csr_data, x)

        if graph.in_norm is not None:
            x = graph.in_norm * x
    else:
        # NOTE(review): out_norm/in_norm are not applied here; presumably
        # graph.edge_weight already incorporates that normalization - confirm.
        row, col = graph.edge_index
        x = spmm_scatter(row, col, graph.edge_weight, x)
    return x
124
125
126class SpMM(torch.nn.Module):

Callers 15

forward — Method · 0.90
forward — Method · 0.90
forward — Method · 0.90
multi_hop_sgc — Function · 0.90
multi_hop_ppr_diffusion — Function · 0.90
rand_prop — Method · 0.90
forward — Method · 0.90
inference — Method · 0.90
inference_batch — Method · 0.90
homo_index — Function · 0.90
__call__ — Method · 0.90
__call__ — Method · 0.90

Calls 4

initialize_spmm — Function · 0.85
initialize_spmm_cpu — Function · 0.85
spmm_scatter — Function · 0.85
is_symmetric — Method · 0.45

Tested by

no test coverage detected