r"""Generalized Sparse Matrix Multiplication interface. It applies the binary operator :attr:`op` to source node features and edge features to produce a message on each edge, then aggregates the messages on destination nodes with :attr:`reduce_op`. .. math:: x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
(gidx, op, reduce_op, u, e)
| 154 | |
| 155 | |
| 156 | def _gspmm(gidx, op, reduce_op, u, e): |
| 157 | r"""Generalized Sparse Matrix Multiplication interface. It applies the binary |
| 158 | operator :attr:`op` to source node features and edge features to produce a message on each edge, |
| 159 | then aggregates the messages on destination nodes with :attr:`reduce_op`. |
| 160 | |
| 161 | .. math:: |
| 162 | x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e)) |
| 163 | |
| 164 | where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`, |
| 165 | :math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` means binary |
| 166 | operator :attr:`op` and :math:`\psi` means reduce operator :attr:`reduce_op`, |
| 167 | :math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`gidx`. |
| 168 | |
| 169 | Note that this function does not handle gradients. |
| 170 | |
| 171 | Parameters |
| 172 | ---------- |
| 173 | gidx : HeteroGraphIndex |
| 174 | The input graph index. |
| 175 | op : str |
| 176 | The binary op's name, could be ``add``, ``sub``, ``mul``, ``div``, ``copy_lhs``, |
| 177 | ``copy_rhs``. |
| 178 | reduce_op : str |
| 179 | Reduce operator, could be ``sum``, ``max``, ``min``. |
| 180 | u : tensor or None |
| 181 | The feature on source nodes, could be None if op is ``copy_rhs``. |
| 182 | e : tensor or None |
| 183 | The feature on edges, could be None if op is ``copy_lhs``. |
| 184 | |
| 185 | Returns |
| 186 | ------- |
| 187 | tuple |
| 188 | The returned tuple is composed of two elements: |
| 189 | - The first element refers to the result tensor. |
| 190 | - The second element refers to a tuple composed of arg_u and arg_e |
| 191 | (which is useful when reducer is `min`/`max`). |
| 192 | |
| 193 | Notes |
| 194 | ----- |
| 195 | This function does not handle gradients. |
| 196 | """ |
| 197 | if gidx.number_of_etypes() != 1: |
| 198 | raise DGLError("We only support gspmm on graph with one edge type") |
| 199 | use_u = op != "copy_rhs" |
| 200 | use_e = op != "copy_lhs" |
| 201 | if use_u and use_e: |
| 202 | if F.dtype(u) != F.dtype(e): |
| 203 | raise DGLError( |
| 204 | "The node features' data type {} doesn't match edge" |
| 205 | " features' data type {}, please convert them to the" |
| 206 | " same type.".format(F.dtype(u), F.dtype(e)) |
| 207 | ) |
| 208 | # deal with scalar features. |
| 209 | expand_u, expand_e = False, False |
| 210 | if use_u: |
| 211 | if F.ndim(u) == 1: |
| 212 | u = F.unsqueeze(u, -1) |
| 213 | expand_u = True |
no test coverage detected