github.com/dmlc/dgl

Function _gspmm

python/dgl/_sparse_ops.py:156–265


(gidx, op, reduce_op, u, e)
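The docstring in the source defines the computation as x_v = ψ(ρ(x_u, x_e)): one message per edge from `op`, then a reduction at each destination node with `reduce_op`. As a minimal sketch of those semantics only, assuming an edge list (`src`, `dst`) in place of a `HeteroGraphIndex` and plain Python lists in place of tensors, the hypothetical helper `reference_gspmm` (not part of DGL) computes the same quantity and returns the documented `(result, (arg_u, arg_e))` shape.

```python
from typing import Callable, Dict, List, Optional, Sequence

Vec = List[float]

# rho: the binary operators named in the docstring. copy_lhs / copy_rhs
# ignore one operand, which may therefore be None.
BINARY_OPS: Dict[str, Callable] = {
    "add": lambda a, b: a + b,
    "sub": lambda a, b: a - b,
    "mul": lambda a, b: a * b,
    "div": lambda a, b: a / b,
    "copy_lhs": lambda a, b: a,
    "copy_rhs": lambda a, b: b,
}


def reference_gspmm(
    src: Sequence[int],
    dst: Sequence[int],
    num_dst: int,
    op: str,
    reduce_op: str,
    u: Optional[Sequence[Vec]],
    e: Optional[Sequence[Vec]],
):
    """Hypothetical dense reference for x_v = psi(rho(x_u, x_e)); not DGL code.

    Edge i runs from src[i] to dst[i]; u[n] is the feature vector of source
    node n and e[i] the feature vector of edge i (same length for both).
    """
    if reduce_op not in ("sum", "max", "min"):
        raise ValueError(f"unsupported reduce_op: {reduce_op}")
    rho = BINARY_OPS[op]
    dim = len(u[0]) if u is not None else len(e[0])
    out = [[0.0] * dim for _ in range(num_dst)]  # rows with no in-edges stay 0
    arg_u = [[-1] * dim for _ in range(num_dst)]  # -1: no contributing edge
    arg_e = [[-1] * dim for _ in range(num_dst)]
    seen = [False] * num_dst
    for eid, (s, d) in enumerate(zip(src, dst)):
        first = not seen[d]  # the first edge into d initializes max/min
        seen[d] = True
        for k in range(dim):
            a = u[s][k] if u is not None else None
            b = e[eid][k] if e is not None else None
            msg = rho(a, b)  # message on edge eid
            if reduce_op == "sum":
                out[d][k] += msg
            elif (
                first
                or (reduce_op == "max" and msg > out[d][k])
                or (reduce_op == "min" and msg < out[d][k])
            ):
                out[d][k] = msg
                arg_u[d][k], arg_e[d][k] = s, eid  # remember the winner
    if reduce_op == "sum":
        return out, (None, None)  # arg indices only matter for max/min
    return out, (arg_u, arg_e)
```

The explicit loops exist only to make the semantics visible; leaving rows with no incoming edges at 0 and marking their arg entries with -1 are choices of this sketch, not claims about DGL's kernels.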

```python
def _gspmm(gidx, op, reduce_op, u, e):
    r"""Generalized Sparse Matrix Multiplication interface. It applies
    :attr:`op` to the source node feature and the edge feature to produce a
    message on each edge, then aggregates the messages on destination nodes
    with :attr:`reduce_op`.

    .. math::
        x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))

    where :math:`x_v` is the returned feature on destination nodes, and
    :math:`x_u`, :math:`x_e` refer to :attr:`u` and :attr:`e` respectively.
    :math:`\rho` denotes the binary operator :attr:`op`, :math:`\psi` denotes
    the reduce operator :attr:`reduce_op`, and :math:`\mathcal{G}` is the
    graph that gspmm is applied on, given by :attr:`gidx`.

    Parameters
    ----------
    gidx : HeteroGraphIndex
        The input graph index.
    op : str
        Name of the binary operator: ``add``, ``sub``, ``mul``, ``div``,
        ``copy_lhs``, or ``copy_rhs``.
    reduce_op : str
        Name of the reduce operator: ``sum``, ``max``, or ``min``.
    u : tensor or None
        The feature on source nodes; may be None if op is ``copy_rhs``.
    e : tensor or None
        The feature on edges; may be None if op is ``copy_lhs``.

    Returns
    -------
    tuple
        A tuple of two elements:

        - The result tensor.
        - A tuple ``(arg_u, arg_e)``, which is useful when the reducer is
          ``min`` or ``max``.

    Notes
    -----
    This function does not handle gradients.
    """
    if gidx.number_of_etypes() != 1:
        raise DGLError("We only support gspmm on graph with one edge type")
    # copy_rhs never reads u, and copy_lhs never reads e.
    use_u = op != "copy_rhs"
    use_e = op != "copy_lhs"
    if use_u and use_e:
        if F.dtype(u) != F.dtype(e):
            raise DGLError(
                "The node features' data type {} doesn't match edge"
                " features' data type {}, please convert them to the"
                " same type.".format(F.dtype(u), F.dtype(e))
            )
    # Deal with scalar features: give 1-D inputs a trailing feature axis.
    expand_u, expand_e = False, False
    if use_u:
        if F.ndim(u) == 1:
            u = F.unsqueeze(u, -1)
            expand_u = True
    # ... (excerpt ends here; the function continues to line 265)
```
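The prologue above can be mirrored on shapes and dtype names alone, which makes the order of the checks easy to follow. In this sketch `check_and_expand` is a hypothetical helper, shapes are plain tuples, and the `expand_e` branch is assumed to mirror the `u` branch, because the excerpt stops before any handling of `e`.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_and_expand(
    num_etypes: int,
    op: str,
    u_shape: Shape,
    u_dtype: str,
    e_shape: Shape,
    e_dtype: str,
) -> Tuple[Shape, Shape, bool, bool]:
    """Hypothetical shape-only mirror of the _gspmm prologue; not DGL code."""
    if num_etypes != 1:
        raise ValueError("We only support gspmm on graph with one edge type")
    use_u = op != "copy_rhs"  # every op except copy_rhs reads u
    use_e = op != "copy_lhs"  # every op except copy_lhs reads e
    if use_u and use_e and u_dtype != e_dtype:
        raise TypeError(f"dtype mismatch: {u_dtype} vs {e_dtype}")
    expand_u = use_u and len(u_shape) == 1
    expand_e = use_e and len(e_shape) == 1  # assumed symmetric to u
    if expand_u:
        u_shape = u_shape + (1,)  # analogous to F.unsqueeze(u, -1)
    if expand_e:
        e_shape = e_shape + (1,)
    return u_shape, e_shape, expand_u, expand_e
```

An unused operand can be passed as an empty shape with any dtype name, since `use_u` and `use_e` short-circuit both the dtype comparison and the expansion.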

Callers (12)

forward (method) · 0.90
backward (method) · 0.90
forward (method) · 0.85
backward (method) · 0.85
backward (method) · 0.85
forward (method) · 0.85
gspmm_real (function) · 0.85
grad (function) · 0.85
edge_softmax_real (function) · 0.85
forward (method) · 0.85
forward (method) · 0.85
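The docstring says `arg_u` and `arg_e` are useful with `min`/`max`, and several callers are backward methods. One plausible use, assumed here rather than read from those methods, is routing gradients: with `op="copy_lhs"` and a max or min reducer, each output element depends on exactly one source node, so its gradient flows only to the source index recorded in `arg_u`. `scatter_grad_to_u` is a hypothetical helper, and treating -1 as "no contributing edge" is an assumption of this sketch.

```python
from typing import List, Sequence


def scatter_grad_to_u(
    grad_out: Sequence[Sequence[float]],
    arg_u: Sequence[Sequence[int]],
    num_src: int,
) -> List[List[float]]:
    """Hypothetical gradient of copy_lhs + max/min with respect to u."""
    dim = len(grad_out[0]) if grad_out else 0
    grad_u = [[0.0] * dim for _ in range(num_src)]
    for v, row in enumerate(arg_u):
        for k, s in enumerate(row):
            if s >= 0:  # -1 marks an output with no contributing edge
                # Only the winning source receives this output's gradient;
                # a source that wins for several outputs accumulates them.
                grad_u[s][k] += grad_out[v][k]
    return grad_u
```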

Calls (12)

DGLError (class) · 0.85
infer_broadcast_shape (function) · 0.85
to_dgl_nd_for_write (function) · 0.85
number_of_etypes (method) · 0.80
format (method) · 0.80
context (method) · 0.80
find_edge (method) · 0.80
to_dgl_nd (function) · 0.70
dtype (method) · 0.45
shape (method) · 0.45
num_nodes (method) · 0.45
num_edges (method) · 0.45
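The call list includes `infer_broadcast_shape`. Assuming it applies NumPy-style broadcasting to the per-row feature shapes of `u` and `e` (every dimension after the leading node or edge dimension), the rule can be sketched as below; `broadcast_feature_shape` is a hypothetical name, not DGL's function.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_feature_shape(shp_u: Shape, shp_e: Shape) -> Shape:
    """Hypothetical NumPy-style broadcast of two per-row feature shapes."""
    n = max(len(shp_u), len(shp_e))
    # Left-pad the shorter shape with 1s so both have n dimensions.
    pu = (1,) * (n - len(shp_u)) + shp_u
    pe = (1,) * (n - len(shp_e)) + shp_e
    out = []
    for du, de in zip(pu, pe):
        if du != de and du != 1 and de != 1:
            raise ValueError(f"cannot broadcast {shp_u} with {shp_e}")
        # A size-1 axis stretches to match the other operand.
        out.append(de if du == 1 else du)
    return tuple(out)
```

Under this rule, for example, an edge feature with per-edge shape `(1,)` can scale a node feature with per-node shape `(4,)` under `op="mul"`, yielding messages of shape `(4,)`.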

Tested by

no test coverage detected