(g, shp, lhs_target, rhs_target, msg, idtype)
| 232 | ) |
| 233 | @parametrize_idtype |
| 234 | def test_sddmm(g, shp, lhs_target, rhs_target, msg, idtype): |
| 235 | if lhs_target == rhs_target: |
| 236 | return |
| 237 | g = g.astype(idtype).to(F.ctx()) |
| 238 | if dgl.backend.backend_name == "mxnet" and g.num_edges() == 0: |
| 239 | pytest.skip() # mxnet do not support zero shape tensor |
| 240 | print(g) |
| 241 | print(g.idtype) |
| 242 | |
| 243 | len_lhs = select( |
| 244 | lhs_target, |
| 245 | g.number_of_src_nodes(), |
| 246 | g.num_edges(), |
| 247 | g.number_of_dst_nodes(), |
| 248 | ) |
| 249 | lhs_shp = (len_lhs,) + shp[0] |
| 250 | len_rhs = select( |
| 251 | rhs_target, |
| 252 | g.number_of_src_nodes(), |
| 253 | g.num_edges(), |
| 254 | g.number_of_dst_nodes(), |
| 255 | ) |
| 256 | rhs_shp = (len_rhs,) + shp[1] |
| 257 | feat_lhs = F.tensor(np.random.rand(*lhs_shp) + 1) |
| 258 | feat_rhs = F.tensor(np.random.rand(*rhs_shp) + 1) |
| 259 | print( |
| 260 | "lhs shape: {}, rhs shape: {}".format( |
| 261 | F.shape(feat_lhs), F.shape(feat_rhs) |
| 262 | ) |
| 263 | ) |
| 264 | |
| 265 | lhs_frame = select(lhs_target, g.srcdata, g.edata, g.dstdata) |
| 266 | rhs_frame = select(rhs_target, g.srcdata, g.edata, g.dstdata) |
| 267 | lhs_frame["x"] = F.attach_grad(F.clone(feat_lhs)) |
| 268 | rhs_frame["y"] = F.attach_grad(F.clone(feat_rhs)) |
| 269 | msg_func = lhs_target + "_" + msg + "_" + rhs_target |
| 270 | print("SDDMM(message func: {})".format(msg_func)) |
| 271 | |
| 272 | lhs = F.attach_grad(F.clone(feat_lhs)) |
| 273 | rhs = F.attach_grad(F.clone(feat_rhs)) |
| 274 | with F.record_grad(): |
| 275 | e = gsddmm( |
| 276 | g, msg, lhs, rhs, lhs_target=lhs_target, rhs_target=rhs_target |
| 277 | ) |
| 278 | F.backward(F.reduce_sum(e)) |
| 279 | grad_lhs = F.grad(lhs) |
| 280 | grad_rhs = F.grad(rhs) |
| 281 | |
| 282 | with F.record_grad(): |
| 283 | g.apply_edges(udf_apply_edges[msg_func]) |
| 284 | if g.num_edges() > 0: |
| 285 | e1 = g.edata["m"] |
| 286 | assert F.allclose(e, e1) |
| 287 | print("forward passed") |
| 288 | |
| 289 | F.backward(F.reduce_sum(e1)) |
| 290 | if msg != "copy_rhs": |
| 291 | assert F.allclose(F.grad(lhs_frame["x"]), grad_lhs) |
nothing calls this directly
no test coverage detected