(emb_dim)
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam(emb_dim):
    """Check that one step of DGL's SparseAdam matches torch.optim.SparseAdam.

    A DGL ``NodeEmbedding`` and a sparse ``torch.nn.Embedding`` are given
    identical initial weights and fed the same random node indices. Each is
    trained for a single optimizer step on the same cross-entropy loss, and
    the resulting weights are compared.

    Only the first step is checked. PyTorch's SparseAdam maintains a global
    step, while DGL's SparseAdam uses a per-embedding step, so the two
    optimizers are not expected to agree after later steps.
    """
    num_embs = 10
    batch_size = 4
    lr = 0.01
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    # Re-seed before each init so both embeddings start from identical weights.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=lr)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=lr)

    # First (and only comparable) step.
    idx = th.randint(0, num_embs, size=(batch_size,))
    # The DGL lookup runs on the test device. Move the result back to CPU,
    # where the torch embedding (created without a device) lives.
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.zeros(batch_size, dtype=th.long)

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)

    # Can not test second step
    # Pytorch sparseAdam maintains a global step
    # DGL sparseAdam use a per embedding step
| 50 | |
| 51 | |
| 52 | @unittest.skipIf(os.name == "nt", reason="Do not support windows yet") |
no test coverage detected