MCPcopy
hub / github.com/dmlc/dgl / test_sparse_adam_zero_step

Function test_sparse_adam_zero_step

tests/python/pytorch/optim/test_optim.py:133–169  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

131
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_sparse_adam_zero_step():
    """Check SparseAdam when one of its embeddings receives no gradient.

    Two embeddings are registered with both DGL's ``SparseAdam`` and
    ``torch.optim.SparseAdam``, but only the first one takes part in the
    forward pass. After a single optimizer step:

    * the used embedding must match PyTorch's reference update, and
    * the unused ("zero-step") embedding must be left unchanged by both
      optimizers and still agree between DGL and PyTorch.
    """
    num_embs = 10
    emb_dim = 4
    batch_size = 4
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, "test2")
    torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)

    # Re-seed before each pair of initializations so the DGL embeddings start
    # from exactly the same weights as their PyTorch counterparts.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)
    th.nn.init.uniform_(dgl_emb_zero.weight, 0, 1.0)

    # Snapshot the embeddings that will get no gradient so we can verify
    # afterwards that the optimizers did not modify them.
    dgl_zero_before = dgl_emb_zero.weight.detach().clone()
    torch_zero_before = torch_emb_zero.weight.detach().clone()

    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    # First (and only) step: only dgl_emb / torch_emb are looked up, so the
    # "_zero" embeddings receive no gradient at all.
    idx = th.randint(0, num_embs, size=(batch_size,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    # Each embedding row is treated as emb_dim logits; every target is class 1.
    labels = th.ones(batch_size, dtype=th.long)

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()

    # The embedding that received a gradient follows PyTorch's update.
    assert F.allclose(dgl_emb.weight, torch_emb.weight)
    # The embedding that received no gradient is untouched by both optimizers
    # and therefore still matches between DGL and PyTorch.
    assert F.allclose(dgl_emb_zero.weight, dgl_zero_before)
    assert F.allclose(torch_emb_zero.weight, torch_zero_before)
    assert F.allclose(dgl_emb_zero.weight, torch_emb_zero.weight)
170
171
172def initializer(emb):

Callers 1

test_optim.py (File) · 0.85

Calls 10

NodeEmbedding (Class) · 0.90
SparseAdam (Class) · 0.90
parameters (Method) · 0.80
ctx (Method) · 0.45
to (Method) · 0.45
device (Method) · 0.45
long (Method) · 0.45
zero_grad (Method) · 0.45
backward (Method) · 0.45
step (Method) · 0.45

Tested by

no test coverage detected