hub / github.com/dmlc/dgl / update

Method update

python/dgl/distributed/optim/pytorch/sparse_optim.py:514–576

Update embeddings in a sparse manner.

Sparse embeddings are updated in mini batches. We maintain gradient states for each embedding so they can be updated separately.

Parameters
----------
idx : tensor
    Index of the embeddings to be updated.

(self, idx, grad, emb)
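The description above can be sketched on plain CPU tensors. This is a minimal illustration, not the DGL implementation: `sparse_adagrad_step` is a hypothetical helper, and `emb` and `state` are assumed to be ordinary dense tensors rather than a `DistEmbedding` and its distributed state. Duplicate indices are merged into one averaged gradient per row, and only the touched rows of the state and the embedding are updated.

```python
import torch as th


def sparse_adagrad_step(emb, state, idx, grad, lr=0.01, eps=1e-10):
    """Apply one Adagrad step, in place, to the rows of `emb` listed in `idx`.

    Hypothetical helper for illustration; `emb` and `state` are dense tensors.
    """
    # Merge duplicate indices: one averaged gradient per unique row.
    uniq, inverse, cnt = th.unique(idx, return_inverse=True, return_counts=True)
    grad_values = th.zeros((uniq.shape[0], grad.shape[1]), dtype=grad.dtype)
    grad_values.index_add_(0, inverse, grad)
    grad_values = grad_values / cnt.unsqueeze(1)

    # Accumulate squared gradients only for the rows that were touched.
    row_state = state[uniq] + grad_values * grad_values
    state[uniq] = row_state

    # Scale each step by the accumulated gradient history of that row.
    emb[uniq] -= lr * grad_values / (row_state.sqrt() + eps)
    return uniq


emb = th.ones(4, 2)
state = th.zeros(4, 2)
idx = th.tensor([2, 0, 2])  # row 2 appears twice in the mini batch
grad = th.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
touched = sparse_adagrad_step(emb, state, idx, grad, lr=0.1)
```

Rows 1 and 3 are absent from `idx`, so neither their embedding values nor their state entries change.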

Source from the content-addressed store, hash-verified

            self._state[emb.name] = state

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner
        Sparse embeddings are updated in mini batches. We maintain gradient states for
        each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.distributed.DistEmbedding
            Sparse embedding to update.
        """
        eps = self._eps
        clr = self._lr

        state_dev = th.device("cpu")
        exec_dev = grad.device

        # only perform async copies cpu -> gpu, or gpu -> gpu, but block
        # when copying to the cpu, so as to ensure the copy is finished
        # before operating on the data on the cpu
        state_block = state_dev == th.device("cpu") and exec_dev != state_dev

        # the update is non-linear so indices must be unique
        grad_indices, inverse, cnt = th.unique(
            idx, return_inverse=True, return_counts=True
        )
        grad_values = th.zeros(
            (grad_indices.shape[0], grad.shape[1]), device=exec_dev
        )
        grad_values.index_add_(0, inverse, grad)
        grad_values = grad_values / cnt.unsqueeze(1)
        grad_sum = grad_values * grad_values

        # update grad state
        grad_state = self._state[emb.name][grad_indices].to(exec_dev)
        grad_state += grad_sum
        grad_state_dst = grad_state.to(state_dev, non_blocking=True)
        if state_block:
            # use events to try and overlap CPU and GPU as much as possible
            update_event = th.cuda.Event()
            update_event.record()

        # update emb
        std_values = grad_state.sqrt_().add_(eps)
        tmp = clr * grad_values / std_values
        tmp_dst = tmp.to(state_dev, non_blocking=True)

        if state_block:
            std_event = th.cuda.Event()
            std_event.record()
            # wait for our transfers from exec_dev to state_dev to finish
            # before we can use them
            update_event.wait()
        self._state[emb.name][grad_indices] = grad_state_dst
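The listing records a `th.cuda.Event` after each non-blocking copy to the CPU and waits on it before the copied data is used. A minimal sketch of that pattern for a single tensor follows. `copy_to_host` is a hypothetical helper, not part of DGL. It uses `Event.synchronize()`, which blocks the host thread until the event completes, whereas `Event.wait()` makes a CUDA stream wait for the event; which of the two the surrounding code needs depends on context not shown here. On a machine without CUDA the sketch falls back to a plain CPU copy.

```python
import torch as th


def copy_to_host(src):
    """Copy `src` to the CPU and return once the data is safe to read there.

    Hypothetical helper for illustration only.
    """
    dst = src.to("cpu", non_blocking=True)
    if src.is_cuda:
        # Record an event on the current stream after the copy is enqueued,
        # then block the host thread until the GPU has reached that event.
        done = th.cuda.Event()
        done.record()
        done.synchronize()
    return dst


dev = th.device("cuda" if th.cuda.is_available() else "cpu")
values = th.arange(6, dtype=th.float32, device=dev).reshape(3, 2)
host_values = copy_to_host(values * 2)
```

Recording the event right after the copy, and synchronizing only when the result is needed, leaves room to enqueue further GPU work in between.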

Callers

nothing calls this directly

Calls 3

device · Method · 0.45
to · Method · 0.45
wait · Method · 0.45

Tested by

no test coverage detected