test_spmm(idtype, dtype, g, shp, msg, reducer)
@parametrize_idtype
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_spmm(idtype, dtype, g, shp, msg, reducer):
    """Check the fused ``gspmm`` kernel against the UDF-based ``update_all`` path.

    The same random source-node features and edge features are fed to
    ``gspmm(g, msg, reducer, u, e)`` and to
    ``g.update_all(udf_msg[msg], udf_reduce[reducer])``. When the graph has
    at least one edge, the forward outputs must match. The gradients of the
    summed output must also match, for the source-node features (skipped
    for ``copy_rhs``) and for the edge features (skipped for ``copy_lhs``).

    Parameters
    ----------
    idtype :
        Index type the graph is cast to with ``g.astype``.
    dtype :
        NumPy float type of the random features (``np.float32`` or
        ``np.float64``).
    g :
        Graph under test; moved to ``F.ctx()`` before use.
    shp :
        Pair of feature-shape tuples: ``shp[0]`` is appended to the number
        of source nodes, ``shp[1]`` to the number of edges.
    msg, reducer :
        Operator names passed to ``gspmm`` and used as keys into
        ``udf_msg`` / ``udf_reduce``.
    """

    def _assert_grad_close(actual, expected):
        """Assert that the UDF-path gradient ``actual`` matches ``expected``.

        For ``min``/``max`` reducers there might be some numerical errors
        (NOTE(review): presumably from ties in the arg-min/arg-max choice),
        so only the summed absolute difference relative to the summed
        magnitude of ``expected`` is required to be below 1e-2. Every other
        reducer uses ``F.allclose``.
        """
        if reducer in ["min", "max"]:
            rate = F.reduce_sum(F.abs(actual - expected)) / F.reduce_sum(
                F.abs(expected)
            )
            assert F.as_scalar(rate) < 1e-2, rate
        else:
            assert F.allclose(actual, expected)

    g = g.astype(idtype).to(F.ctx())
    print(g)
    print(g.idtype)
    has_edges = g.num_edges() > 0

    # np.random.rand samples [0, 1); the +1 shifts features into [1, 2),
    # presumably so "div" messages never divide by zero.
    hu = F.tensor(
        np.random.rand(*((g.number_of_src_nodes(),) + shp[0])).astype(dtype) + 1
    )
    he = F.tensor(
        np.random.rand(*((g.num_edges(),) + shp[1])).astype(dtype) + 1
    )
    print("u shape: {}, e shape: {}".format(F.shape(hu), F.shape(he)))

    # The UDF path reads features from the graph; the fused path uses the
    # separate clones u/e below, so each path accumulates its own gradient.
    g.srcdata["x"] = F.attach_grad(F.clone(hu))
    g.edata["w"] = F.attach_grad(F.clone(he))
    print("SpMM(message func: {}, reduce func: {})".format(msg, reducer))

    u = F.attach_grad(F.clone(hu))
    e = F.attach_grad(F.clone(he))
    with F.record_grad():
        v = gspmm(g, msg, reducer, u, e)
        if reducer in ["max", "min"]:
            # min/max can yield infinities (presumably for destination nodes
            # without in-edges); zero them so the comparison and the summed
            # loss below stay finite.
            v = F.replace_inf_with_zero(v)
        if has_edges:
            F.backward(F.reduce_sum(v))
            # grad_u / grad_e are bound only under these conditions; the
            # comparisons below are guarded by exactly the same conditions.
            if msg != "copy_rhs":
                grad_u = F.grad(u)
            if msg != "copy_lhs":
                grad_e = F.grad(e)

    with F.record_grad():
        g.update_all(udf_msg[msg], udf_reduce[reducer])
        if has_edges:
            v1 = g.dstdata["v"]
            assert F.allclose(v, v1)
            print("forward passed")

            F.backward(F.reduce_sum(v1))
            if msg != "copy_rhs":
                _assert_grad_close(F.grad(g.srcdata["x"]), grad_u)
            if msg != "copy_lhs":
                _assert_grad_close(F.grad(g.edata["w"]), grad_e)
            print("backward passed")
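Nothing calls this directly; pytest collects and runs it as a parametrized test.
No test coverage detected (expected, since this function is itself a test).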