hub / github.com/dmlc/dgl / start_sparse_adam_worker

Function start_sparse_adam_worker

tests/python/pytorch/optim/test_optim.py:178–242
(
    rank,
    device,
    world_size,
    weight,
    tensor_dev="cpu",
    has_zero_grad=False,
    backend="gloo",
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
)

Source from the content-addressed store, hash-verified

def start_sparse_adam_worker(
    rank,
    device,
    world_size,
    weight,
    tensor_dev="cpu",
    has_zero_grad=False,
    backend="gloo",
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
):
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    if device.type == "cuda":
        th.cuda.set_device(device)

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    init_weight = th.empty((num_embs, emb_dim))
    th.manual_seed(0)
    th.nn.init.uniform_(init_weight, -1.0, 1.0)
    dgl_emb = NodeEmbedding(
        num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
    )
    dgl_emb.all_set_embedding(init_weight)

    if has_zero_grad:
        dgl_emb_zero = NodeEmbedding(
            num_embs, emb_dim, "zero", init_func=initializer, device=tensor_dev
        )
        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    th.manual_seed(rank)
    if zero_comm:
        start = (num_embs // world_size) * rank
        end = (num_embs // world_size) * (rank + 1)
        idx = th.randint(start, end, size=(4,)).to(tensor_dev)
    else:
        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()
    dgl_weight = dgl_emb.all_get_embedding().detach()
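
When `zero_comm` is true, the listing above has each rank sample its four indices from its own contiguous, non-overlapping slice of the embedding table, so no two ranks touch the same row. The arithmetic can be sketched in plain Python without torch. The helper name `partition_range` is hypothetical and not part of DGL; it only mirrors the `start`/`end` computation shown in the source.

```python
def partition_range(rank, world_size, num_embs=128):
    """Return the half-open [start, end) index range sampled by `rank`.

    Hypothetical helper: mirrors the start/end arithmetic in the
    zero_comm branch of start_sparse_adam_worker.
    """
    chunk = num_embs // world_size  # floor division, as in the source
    return chunk * rank, chunk * (rank + 1)


if __name__ == "__main__":
    # With the default num_embs=128 and 4 workers, each rank owns 32 rows.
    for rank in range(4):
        print(rank, partition_range(rank, 4))
```

Because the chunk size uses floor division, any remainder rows at the end of the table fall outside every rank's range when `num_embs` is not a multiple of `world_size`. For example, with `num_embs=10` and three ranks, row 9 is never sampled.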

Callers

nothing calls this directly

Calls 13

all_set_embedding · Method · 0.95
all_get_embedding · Method · 0.95
NodeEmbedding · Class · 0.90
SparseAdam · Class · 0.90
format · Method · 0.80
set_device · Method · 0.45
to · Method · 0.45
long · Method · 0.45
zero_grad · Method · 0.45
backward · Method · 0.45
step · Method · 0.45
barrier · Method · 0.45

Tested by

no test coverage detected