MCPcopy
hub / github.com/dmlc/dgl / test_sparse_adam_uva

Function test_sparse_adam_uva

tests/python/pytorch/optim/test_optim.py:55–91  ·  view source on GitHub ↗
(use_uva, emb_dim)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("use_uva", [False, True, None])
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam_uva(use_uva, emb_dim):
    """Check that one DGL ``SparseAdam`` step matches ``torch.optim.SparseAdam``.

    A DGL ``NodeEmbedding`` and a sparse ``torch.nn.Embedding`` are
    initialised with identical values. Both go through the same
    forward/backward pass on the same row indices and are stepped once by
    their respective optimizers. Their full weight tables are then compared.

    Args:
        use_uva: Forwarded to DGL ``SparseAdam``. ``True`` is skipped when
            the test context is a CPU.
        emb_dim: Embedding dimension; each looked-up row is used as the
            logits of a cross-entropy loss.
    """
    device = F.ctx()
    if device.type == "cpu" and use_uva is True:
        # Without a GPU, only use_uva=False and use_uva=None are tested.
        pytest.skip("UVA cannot be used without GPUs.")

    num_embs = 10
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test_uva{}".format(use_uva))
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    # Re-seed before each init so both tables draw identical uniform values.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01, use_uva=use_uva)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # First (and only) step: look up the same 4 rows in both tables and treat
    # each row as logits with target class 0.
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    # With emb_dim == 1 there is a single class, so both the cross-entropy
    # loss and its gradient are zero. That case only checks that the two
    # optimizers agree on a zero-gradient step.
    labels = th.zeros((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)

    # A second step cannot be compared. PyTorch SparseAdam keeps one global
    # step counter, while DGL SparseAdam keeps a per-embedding step, so the
    # step counts used for bias correction can differ after the first step.
92
93
94@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")

Callers

nothing calls this directly

Calls 11

NodeEmbedding (class) · 0.90
SparseAdam (class) · 0.90
format (method) · 0.80
parameters (method) · 0.80
ctx (method) · 0.45
to (method) · 0.45
device (method) · 0.45
long (method) · 0.45
zero_grad (method) · 0.45
backward (method) · 0.45
step (method) · 0.45

Tested by

no test coverage detected