| 3396 | |
@parametrize_idtype
def test_module_svd_pe(idtype):
    """Exercise ``dgl.SVDPE`` on a small graph, with and without padding.

    Two cases are covered:

    * ``k=2`` without padding: the element-wise absolute value of the
      resulting ``"svd_pe"`` node feature is compared against hard-coded
      reference values.
    * ``k=6`` with ``padding=True``: only the shape of the resulting
      feature, ``(5, 12)``, is checked.
    """
    src_ids = [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4]
    dst_ids = [2, 3, 0, 2, 0, 2, 3, 4, 3, 4, 0, 1]
    graph = dgl.graph((src_ids, dst_ids), idtype=idtype, device=F.ctx())
    feat_key = "svd_pe"

    # Case 1: k=2, no padding.
    expected_rows = [
        [0.6669, 0.3068, 0.7979, 0.8477],
        [0.6311, 0.6101, 0.1248, 0.5137],
        [1.1993, 0.0665, 0.9183, 0.1455],
        [0.5682, 0.6766, 0.8952, 0.6449],
        [0.3393, 0.8363, 0.6500, 0.4564],
    ]
    expected_pe = F.copy_to(F.tensor(expected_rows), graph.device)
    unpadded = dgl.SVDPE(k=2, feat_name=feat_key)(graph)
    encoding = unpadded.ndata[feat_key]
    # Only magnitudes are compared against the reference values
    # (presumably because the sign of the encoding is not fixed -- confirm).
    # The tensorflow backend goes through ``__abs__``; every other backend
    # uses the tensor's ``abs`` method.
    if dgl.backend.backend_name != "tensorflow":
        magnitudes = encoding.abs()
    else:
        magnitudes = encoding.__abs__()
    assert F.allclose(magnitudes, expected_pe)

    # Case 2: k=6 with padding enabled; only the output shape is verified.
    padded = dgl.SVDPE(k=6, feat_name=feat_key, padding=True)(graph)
    assert F.shape(padded.ndata[feat_key]) == (5, 12)
| 3431 | |
| 3432 | |
| 3433 | if __name__ == "__main__": |