MCPcopy
hub / github.com/dmlc/dgl / test_segment_mm

Function test_segment_mm

tests/python/common/ops/test_ops.py:350–391  ·  view source on GitHub ↗
(idtype, feat_size, dtype, tol)

Source from the content-addressed store, hash-verified

348 ],
349)
def test_segment_mm(idtype, feat_size, dtype, tol):
    """Check ``dgl.ops.segment_mm`` against a per-segment matmul reference.

    The rows of ``a`` are split into consecutive segments whose lengths are
    given by ``seglen_a``. Segment ``i`` is multiplied by ``b[i]``. The forward
    result and the gradients w.r.t. ``a`` and ``b`` are compared against a
    reference built from plain ``torch`` matmuls.

    Args:
        idtype: integer dtype used for the segment-length tensor.
        feat_size: number of input features (columns of ``a``); each ``b[i]``
            maps ``feat_size`` features to ``feat_size + 1``.
        dtype: floating-point dtype of ``a``, ``b`` and the upstream gradient.
        tol: used as both ``atol`` and ``rtol`` in ``torch.allclose``.
    """
    if F._default_context_str == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    if (
        F._default_context_str == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")
    dev = F.ctx()

    # Segment lengths, one per matrix in ``b``; zero-length segments are
    # included on purpose. Unlike the float inputs this tensor is not moved
    # to ``dev`` -- presumably segment_mm expects it on the host; confirm.
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
    # Derive shapes from seglen_a so the inputs always match the
    # segmentation (total rows == sum of segment lengths).
    num_rows = int(seglen_a.sum())
    num_segments = len(seglen_a)

    # input
    a = torch.tensor(np.random.rand(num_rows, feat_size)).to(dev).to(dtype)
    a.requires_grad_()
    b = (
        torch.tensor(np.random.rand(num_segments, feat_size, feat_size + 1))
        .to(dev)
        .to(dtype)
    )
    b.requires_grad_()
    dc = (
        torch.tensor(np.random.rand(num_rows, feat_size + 1))
        .to(dev)
        .to(dtype)
    )

    # compute
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    # Clone: the .grad buffers are zeroed and reused by the reference pass.
    da = a.grad.clone()
    db = b.grad.clone()

    # ground truth: multiply each row segment by its own matrix
    c_t = []
    off = 0
    for i, seglen in enumerate(seglen_a):
        c_t.append(a[off : off + seglen] @ b[i])
        off += seglen
    c_t = torch.cat(c_t)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t = a.grad
    db_t = b.grad

    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
392
393
394@unittest.skipIf(

Callers

nothing calls this directly

Calls 5

append (method) · 0.80
ctx (method) · 0.45
to (method) · 0.45
backward (method) · 0.45
clone (method) · 0.45

Tested by

no test coverage detected