MCPcopy
hub / github.com/dmlc/dgl / test_segment_reduce

Function test_segment_reduce

tests/python/common/ops/test_ops.py:303–333  ·  view source on GitHub ↗
(reducer)

Source from the content-addressed store, hash-verified

301
@pytest.mark.parametrize("reducer", ["sum", "max", "min", "mean"])
def test_segment_reduce(reducer):
    """Check that ``segment_reduce`` agrees with an equivalent ``gspmm`` call.

    The rows of a random 10x5 tensor are split into consecutive segments
    whose lengths are listed in ``seglen`` (some segments are empty). The
    same reduction is then computed two ways:

    * through ``gspmm`` with ``copy_lhs`` on a bipartite graph whose edge
      ``i`` runs from source node ``i`` to the segment that owns row ``i``;
    * directly through ``segment_reduce``.

    Both the forward outputs and the gradients of their sums with respect
    to the input must match.
    """
    dev = F.ctx()
    feats = F.tensor(np.random.rand(10, 5))
    # Two independent leaf copies so each path accumulates its own gradient.
    x_spmm = F.attach_grad(F.clone(feats))
    x_seg = F.attach_grad(F.clone(feats))
    seglen = F.tensor([2, 3, 0, 4, 1, 0, 0])

    n_rows = F.shape(feats)[0]
    n_segs = len(seglen)
    # Source ids are 0..n_rows-1; destination ids repeat each segment id
    # seglen[k] times, so dst[i] is the segment containing row i.
    src = F.copy_to(F.arange(0, n_rows, F.int32), dev)
    dst = F.repeat(
        F.copy_to(F.arange(0, n_segs, F.int32), dev), seglen, dim=0
    )
    graph = dgl.convert.heterograph(
        {("_U", "_E", "_V"): (src, dst)},
        num_nodes_dict={"_U": len(src), "_V": n_segs},
    )

    # Path 1: message passing with gspmm.
    with F.record_grad():
        out_spmm = gspmm(graph, "copy_lhs", reducer, x_spmm, None)
        if reducer in ("max", "min"):
            # NOTE(review): presumably gspmm leaves zero-in-degree (empty
            # segment) nodes at +/-inf for max/min; zero them before comparing.
            out_spmm = F.replace_inf_with_zero(out_spmm)
        F.backward(F.reduce_sum(out_spmm))
    grad_spmm = F.grad(x_spmm)

    # Path 2: direct segment reduction; forward results are compared here.
    with F.record_grad():
        out_seg = segment_reduce(seglen, x_seg, reducer=reducer)
        F.backward(F.reduce_sum(out_seg))
        assert F.allclose(out_spmm, out_seg)
        print("forward passed")

    grad_seg = F.grad(x_seg)
    assert F.allclose(grad_spmm, grad_seg)
    print("backward passed")
334
335
336@unittest.skipIf(

Callers

nothing calls this directly (it is a pytest test, collected and run by the test runner)

Calls 8

gspmm · Function · 0.90
segment_reduce · Function · 0.90
grad · Method · 0.80
ctx · Method · 0.45
clone · Method · 0.45
copy_to · Method · 0.45
shape · Method · 0.45
backward · Method · 0.45

Tested by

no test coverage detected (this function is itself a test)