MCPcopy
hub / github.com/dmlc/dgl / test_sparse_adam

Function test_sparse_adam

tests/python/pytorch/optim/test_optim.py:15–49  ·  view source on GitHub ↗
(emb_dim)

Source from the content-addressed store, hash-verified

@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam(emb_dim):
    """Check that one DGL ``SparseAdam`` step matches ``torch.optim.SparseAdam``.

    A DGL ``NodeEmbedding`` and a sparse ``torch.nn.Embedding`` are
    initialised with identical values (same seed, same uniform init). Both
    go through the same lookup, cross-entropy loss and backward pass, and
    each is updated once by its own optimizer. The resulting weight tables
    must agree.

    Only the first step can be compared. PyTorch's ``SparseAdam`` maintains
    a global step, while DGL's ``SparseAdam`` uses a per-embedding step, so
    the two are not expected to agree after further steps.

    Args:
        emb_dim: embedding dimension under test.
    """
    num_embs = 10
    # Number of rows looked up per step; idx and labels must use the same size.
    batch_size = 4
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)

    # Reseed before each init so both tables receive identical random values.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # First (and only comparable) step. randint samples with replacement,
    # so idx may contain repeated rows.
    idx = th.randint(0, num_embs, size=(batch_size,))
    # torch_emb was created without a device (CPU), so bring the DGL lookup
    # back to the CPU before computing the loss.
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    # Each embedding row is treated as logits over emb_dim classes; all
    # targets are class 0. With emb_dim == 1 there is only one class, so the
    # loss is identically zero.
    labels = th.zeros(batch_size, dtype=th.long)

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)
50
51
52@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")

Callers 1

test_optim.py (File) · 0.85

Calls 11

NodeEmbedding (Class) · 0.90
SparseAdam (Class) · 0.90
parameters (Method) · 0.80
format (Method) · 0.80
ctx (Method) · 0.45
to (Method) · 0.45
device (Method) · 0.45
long (Method) · 0.45
zero_grad (Method) · 0.45
backward (Method) · 0.45
step (Method) · 0.45

Tested by

no test coverage detected