hub / github.com/dmlc/dgl / update

Method update

python/dgl/distributed/optim/pytorch/sparse_optim.py:658–747

Update embeddings in a sparse manner.

Sparse embeddings are updated in mini-batches. We maintain gradient states for each embedding so they can be updated separately.

Parameters
----------
idx : tensor
    Index of the embeddings to be updated.

(self, idx, grad, emb)
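The idea of keeping separate gradient state per embedding row can be illustrated with a minimal, dependency-free sketch. Everything here is hypothetical: `TinySparseAdam` is not part of DGL, plain Python lists stand in for tensors and `DistTensor` state, `idx` is assumed to hold unique row indices, and the bias-correction step follows the standard Adam formulation rather than anything shown in the excerpt.

```python
import math


class TinySparseAdam:
    """Hypothetical sketch of Adam with per-row state (not DGL's class)."""

    def __init__(self, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        # row index -> (step, first moment, second moment).
        # Rows never touched by a mini-batch have no entry at all.
        self.state = {}

    def update(self, idx, grad, emb):
        """Apply one Adam step to the rows of `emb` listed in `idx`.

        Assumes `idx` contains unique row indices and `grad[k]` is the
        gradient for row `idx[k]`.
        """
        for row, g in zip(idx, grad):
            zeros = [0.0] * len(g)
            step, mem, power = self.state.get(row, (0, zeros, zeros))
            step += 1
            mem = [self.beta1 * m + (1.0 - self.beta1) * x
                   for m, x in zip(mem, g)]
            power = [self.beta2 * p + (1.0 - self.beta2) * x * x
                     for p, x in zip(power, g)]
            self.state[row] = (step, mem, power)
            # Standard Adam bias correction (an assumption here).
            # step >= 1 at this point, so both denominators are non-zero.
            mem_corr = [m / (1.0 - self.beta1 ** step) for m in mem]
            power_corr = [p / (1.0 - self.beta2 ** step) for p in power]
            emb[row] = [w - self.lr * m / (math.sqrt(p) + self.eps)
                        for w, m, p in zip(emb[row], mem_corr, power_corr)]


# Usage: only rows 0 and 2 receive gradients; row 1 and its state are untouched.
opt = TinySparseAdam(lr=0.01)
table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
opt.update([0, 2], [[0.5, -0.5], [1.0, 1.0]], table)
```

Because the step counter is stored per row, a row that appears in few mini-batches keeps its own bias-correction schedule instead of sharing a global one.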


            self._state[emb.name] = state

    def update(self, idx, grad, emb):
        """Update embeddings in a sparse manner
        Sparse embeddings are updated in mini batches. We maintain gradient states for
        each embedding so they can be updated separately.

        Parameters
        ----------
        idx : tensor
            Index of the embeddings to be updated.
        grad : tensor
            Gradient of each embedding.
        emb : dgl.distributed.DistEmbedding
            Sparse embedding to update.
        """
        beta1 = self._beta1
        beta2 = self._beta2
        eps = self._eps
        clr = self._lr
        state_step, state_mem, state_power = self._state[emb.name]

        state_dev = th.device("cpu")
        exec_dev = grad.device

        # only perform async copies cpu -> gpu, or gpu -> gpu, but block
        # when copying to the cpu, so as to ensure the copy is finished
        # before operating on the data on the cpu
        state_block = state_dev == th.device("cpu") and exec_dev != state_dev

        # the update is non-linear so indices must be unique
        grad_indices, inverse, cnt = th.unique(
            idx, return_inverse=True, return_counts=True
        )
        # update grad state
        state_idx = grad_indices.to(state_dev)
        # The original implementation causes read/write contention:
        #     state_step[state_idx] += 1
        #     state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
        # In a distributed environment, the first line sends asynchronous write
        # requests to the kvstore servers to update state_step, and the second line
        # sends read requests to the kvstore servers. The write and read requests
        # may be handled by different kvstore servers managing the same portion of
        # the state_step dist tensor on the same node. As a result, the read request
        # may return an old value (i.e., 0 in the first iteration), which causes
        # update_power_corr to be NaN.
        state_val = state_step[state_idx] + 1
        state_step[state_idx] = state_val
        state_step = state_val.to(exec_dev)
        orig_mem = state_mem[state_idx].to(exec_dev)
        orig_power = state_power[state_idx].to(exec_dev)

        grad_values = th.zeros(
            (grad_indices.shape[0], grad.shape[1]), device=exec_dev
        )
        grad_values.index_add_(0, inverse, grad)
        grad_values = grad_values / cnt.unsqueeze(1)
        grad_mem = grad_values
        grad_power = grad_values * grad_values
        update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem
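The excerpt deduplicates `idx` with `th.unique` and averages the gradient rows that share an index before touching the optimizer state, because the update is non-linear and a repeated index would otherwise be stepped more than once. A plain-Python sketch of that aggregation, under the assumption that lists stand in for tensors (`average_duplicate_grads` is a hypothetical helper, not a DGL or PyTorch function):

```python
def average_duplicate_grads(idx, grad):
    """Collapse repeated indices, averaging their gradient rows.

    Returns (unique_indices, averaged_rows). The unique indices are
    sorted, matching the default ordering of torch.unique.
    """
    sums, counts = {}, {}
    for i, g in zip(idx, grad):
        if i not in sums:
            sums[i] = [0.0] * len(g)
            counts[i] = 0
        # Accumulate this row into the running sum for index i,
        # mirroring index_add_ over the inverse mapping.
        sums[i] = [s + x for s, x in zip(sums[i], g)]
        counts[i] += 1
    unique = sorted(sums)
    # Divide each summed row by its occurrence count (cnt in the excerpt).
    averaged = [[s / counts[i] for s in sums[i]] for i in unique]
    return unique, averaged


# Usage: index 3 appears twice, so its two gradient rows are averaged.
unique, averaged = average_duplicate_grads(
    [3, 1, 3], [[1.0, 2.0], [4.0, 4.0], [3.0, 6.0]]
)
```

After this step each embedding row has exactly one gradient, so its step counter and moment estimates advance once per mini-batch.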

Callers 10

_set_lazy_features · Function · 0.45
local_state_dict · Method · 0.45
load_local_state_dict · Method · 0.45
_get_hash · Method · 0.45
_get_hash_url_suffix · Method · 0.45
_md5sum · Method · 0.45
_md5sum · Method · 0.45
download · Function · 0.45
check_sha1 · Function · 0.45

Calls 3

device · Method · 0.45
to · Method · 0.45
wait · Method · 0.45

Tested by

no test coverage detected