| 279 | |
| 280 | |
def edge_softmax_real(gidx, score, eids=ALL, norm_by="dst"):
    """Compute a softmax over edge scores and return it with its gradient fn.

    Scores are normalized within groups of edges that share the endpoint
    selected by ``norm_by``. The ``_gspmm``/``_gsddmm`` kernels below reduce
    and broadcast along the ``"e"``/``"v"`` operands, i.e. the destination
    side. ``norm_by="src"`` is therefore handled by reversing the graph
    first. Any value other than ``"src"`` behaves exactly like ``"dst"``.

    Parameters
    ----------
    gidx :
        Graph index object; must provide ``edge_subgraph`` and ``reverse``.
    score :
        Per-edge scores to normalize.
    eids :
        Edges to restrict the computation to. ``ALL`` means every edge;
        otherwise an edge subgraph built from ``[eids]`` is used.
    norm_by : str
        ``"dst"`` (default) or ``"src"``.

    Returns
    -------
    tuple
        ``(out, edge_softmax_backward)``: the normalized edge scores and a
        closure mapping an upstream gradient to the gradient w.r.t.
        ``score``. The closure is presumably consumed by
        ``tf.custom_gradient``; confirm at the call site.
    """
    graph = gidx if is_all(eids) else gidx.edge_subgraph([eids], True).graph
    if norm_by == "src":
        graph = graph.reverse()

    # Subtract the per-group maximum before exponentiating so exp() cannot
    # overflow; the shift cancels in the ratio computed below.
    group_max = _gspmm(graph, "copy_rhs", "max", None, score)[0]
    exp_score = tf.math.exp(_gsddmm(graph, "sub", score, group_max, "e", "v"))
    group_sum = _gspmm(graph, "copy_rhs", "sum", None, exp_score)[0]
    out = _gsddmm(graph, "div", exp_score, group_sum, "e", "v")

    def edge_softmax_backward(grad_out):
        # Softmax vector-Jacobian product:
        #   d_score = out * g - out * sum_over_group(out * g)
        # NOTE(review): unlike the forward pass, this calls gspmm/gsddmm
        # (no leading underscore) and does not index [0]. Presumably these
        # are the differentiable wrappers that return a single tensor, kept
        # so higher-order gradients still flow; confirm.
        weighted = out * grad_out
        group_total = gspmm(graph, "copy_rhs", "sum", None, weighted)
        return weighted - gsddmm(graph, "mul", out, group_total, "e", "v")

    return out, edge_softmax_backward
| 298 | |
| 299 | |
| 300 | def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"): |