| 348 | ], |
| 349 | ) |
def test_segment_mm(idtype, feat_size, dtype, tol):
    """Check ``dgl.ops.segment_mm`` against a per-segment matmul reference.

    The 100 rows of ``a`` are split into consecutive segments whose lengths
    come from ``seglen_a`` (some segments are empty), and segment ``i`` is
    multiplied by ``b[i]``. The forward output and the gradients w.r.t. ``a``
    and ``b`` are compared with an explicit PyTorch loop computed in the
    same ``dtype``, using ``tol`` as both the absolute and relative tolerance.
    """
    ctx_name = F._default_context_str
    if ctx_name == "cpu" and dtype == torch.float16:
        pytest.skip("float16 is not supported on CPU.")
    # Keep the short-circuit order: the CUDA bf16 query only runs on GPU.
    if (
        ctx_name == "gpu"
        and dtype == torch.bfloat16
        and not torch.cuda.is_bf16_supported()
    ):
        pytest.skip("BF16 is not supported.")
    dev = F.ctx()

    def rand_on_dev(*shape):
        # float64 uniform samples -> test device -> requested dtype.
        return torch.tensor(np.random.rand(*shape)).to(dev).to(dtype)

    # Inputs (random draws happen in the order a, b, dc).
    a = rand_on_dev(100, feat_size).requires_grad_()
    b = rand_on_dev(10, feat_size, feat_size + 1).requires_grad_()
    # Lengths sum to 100 and include empty segments. They are not moved to
    # ``dev``; presumably segment_mm expects them on the host -- confirm.
    seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
    dc = rand_on_dev(100, feat_size + 1)

    # Kernel under test: forward, backward, snapshot the gradients.
    c = dgl.ops.segment_mm(a, b, seglen_a)
    c.backward(dc)
    da, db = a.grad.clone(), b.grad.clone()

    # Reference: one ordinary matmul per segment, concatenated back together.
    row_chunks = torch.split(a, seglen_a.tolist())
    c_t = torch.cat(
        [chunk @ weight for chunk, weight in zip(row_chunks, b)]
    ).to(dtype)
    a.grad.zero_()
    b.grad.zero_()
    c_t.backward(dc)
    da_t, db_t = a.grad, b.grad

    for actual, expected in ((c, c_t), (da, da_t), (db, db_t)):
        assert torch.allclose(actual, expected, atol=tol, rtol=tol)
| 392 | |
| 393 | |
| 394 | @unittest.skipIf( |