| 301 | |
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
    """Compare ``segment_reduce`` with an equivalent ``gspmm`` copy_lhs reduction.

    Ten feature rows are split into consecutive segments whose lengths are
    given by ``seglen`` (some lengths are zero). The same split is expressed
    as a bipartite graph, source row -> owning segment, and reduced with
    ``gspmm``. Both the forward outputs and the gradients w.r.t. the input
    features must match.
    """
    dev = F.ctx()
    base_feat = F.tensor(np.random.rand(10, 5))
    # Separate leaf copies so each code path accumulates its own gradient.
    feat_spmm = F.attach_grad(F.clone(base_feat))
    feat_seg = F.attach_grad(F.clone(base_feat))

    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])
    num_segs = len(seglen)
    src_ids = F.copy_to(F.arange(0, F.shape(base_feat)[0], F.int32), dev)
    seg_ids = F.copy_to(F.arange(0, num_segs, F.int32), dev)
    # Segment id k is repeated seglen[k] times, giving each row's destination.
    dst_ids = F.repeat(seg_ids, seglen, dim=0)

    graph = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (src_ids, dst_ids)},
        num_nodes_dict={"_U": len(src_ids), "_V": num_segs},
    )

    # Reference path: message passing with copy_lhs + reducer.
    with F.record_grad():
        out_spmm = gspmm(graph, "copy_lhs", reducer, feat_spmm, None)
        if reducer in ("max", "min"):
            # NOTE(review): presumably zero-length segments come back as
            # +/-inf from gspmm max/min; zero them so the outputs compare.
            out_spmm = F.replace_inf_with_zero(out_spmm)
        F.backward(F.reduce_sum(out_spmm))
        grad_spmm = F.grad(feat_spmm)

    # Path under test: direct segment reduction over the same rows.
    with F.record_grad():
        out_seg = segment_reduce(seglen, feat_seg, reducer=reducer)
        F.backward(F.reduce_sum(out_seg))
        assert F.allclose(out_spmm, out_seg)
        print("forward passed")

        grad_seg = F.grad(feat_seg)
        assert F.allclose(grad_spmm, grad_seg)
        print("backward passed")
| 335 | |
| 336 | @unittest.skipIf( |