(msg, x, y)
| 38 | |
| 39 | |
| 40 | def binary_op(msg, x, y): |
| 41 | if msg == "add": |
| 42 | return x + y |
| 43 | elif msg == "sub": |
| 44 | return x - y |
| 45 | elif msg == "mul": |
| 46 | return x * y |
| 47 | elif msg == "div": |
| 48 | return x / y |
| 49 | elif msg == "dot": |
| 50 | return F.sum(x * y, -1, keepdims=True) |
| 51 | elif msg == "copy_lhs": |
| 52 | return x |
| 53 | elif msg == "copy_rhs": |
| 54 | return y |
| 55 | |
| 56 | |
| 57 | def edge_func(lhs_target, rhs_target, msg): |