()
| 131 | |
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_sparse_adam_zero_step():
    """Compare DGL and PyTorch SparseAdam when one embedding is never used.

    Two embedding tables are registered with each optimizer, but only the
    first one of each pair is looked up and back-propagated through in this
    test. After a single optimizer step, the weights of the looked-up DGL
    embedding must match those of its PyTorch counterpart.
    """
    n_rows, width = 10, 4
    ctx = F.ctx()

    dgl_used = NodeEmbedding(n_rows, width, "test")
    torch_used = th.nn.Embedding(n_rows, width, sparse=True)
    dgl_idle = NodeEmbedding(n_rows, width, "test2")
    torch_idle = th.nn.Embedding(n_rows, width, sparse=True)

    # Reseed before initialising each framework's tables so that both pairs
    # are filled from the same random stream, in the same order.
    th.manual_seed(0)
    for table in (torch_used, torch_idle):
        th.nn.init.uniform_(table.weight, 0, 1.0)
    th.manual_seed(0)
    for table in (dgl_used, dgl_idle):
        th.nn.init.uniform_(table.weight, 0, 1.0)

    dgl_opt = SparseAdam(params=[dgl_used, dgl_idle], lr=0.01)
    torch_params = [*torch_used.parameters(), *torch_idle.parameters()]
    torch_opt = th.optim.SparseAdam(torch_params, lr=0.01)

    # first step: only the "used" tables take part in the forward pass
    rows = th.randint(0, n_rows, size=(4,))
    targets = th.ones((4,)).long()
    dgl_out = dgl_used(rows, ctx).to(th.device("cpu"))
    torch_out = torch_used(rows)

    dgl_opt.zero_grad()
    torch_opt.zero_grad()

    cross_entropy = th.nn.functional.cross_entropy
    dgl_objective = cross_entropy(dgl_out, targets)
    torch_objective = cross_entropy(torch_out, targets)
    dgl_objective.backward()
    torch_objective.backward()

    dgl_opt.step()
    torch_opt.step()

    assert F.allclose(dgl_used.weight, torch_used.weight)
| 170 | |
| 171 | |
| 172 | def initializer(emb): |
no test coverage detected