MCPcopy
hub / github.com/dmlc/dgl / update

Method update

python/dgl/optim/pytorch/sparse_optim.py:814–938  ·  view source on GitHub ↗

Update embeddings in a sparse manner. Sparse embeddings are updated in mini-batches. We maintain gradient states for each embedding so they can be updated separately. Parameters: `idx` (tensor) — index of the embeddings to be updated.

(self, idx, grad, emb)

Source from the content-addressed store, hash-verified

812 emb.set_optm_state(state)
813
814 def update(self, idx, grad, emb):
815 """Update embeddings in a sparse manner
816 Sparse embeddings are updated in mini batches. We maintain gradient states for
817 each embedding so they can be updated separately.
818
819 Parameters
820 ----------
821 idx : tensor
822 Index of the embeddings to be updated.
823 grad : tensor
824 Gradient of each embedding.
825 emb : dgl.nn.NodeEmbedding
826 Sparse embedding to update.
827 """
828 with th.no_grad():
829 state_step, state_mem, state_power = emb.optm_state
830 exec_dtype = grad.dtype
831 exec_dev = grad.device
832 state_dev = state_step.device
833
834 # whether or not we need to transfer data from the GPU to the CPU
835 # while updating the weights
836 is_d2h = state_dev.type == "cpu" and exec_dev.type == "cuda"
837
838 # only perform async copies cpu -> gpu, or gpu-> gpu, but block
839 # when copying to the cpu, so as to ensure the copy is finished
840 # before operating on the data on the cpu
841 state_block = is_d2h
842
843 if self._is_using_uva[emb.name] is None and is_d2h:
844 # we should use UVA going forward
845 self._setup_uva(emb.name, state_mem, state_power)
846 elif self._is_using_uva[emb.name] is None:
847 # we shouldn't use UVA going forward
848 self._is_using_uva[emb.name] = False
849
850 use_uva = self._is_using_uva[emb.name]
851
852 beta1 = self._beta1
853 beta2 = self._beta2
854 eps = self._eps
855
856 clr = self._lr
857 # There can be duplicated indices due to sampling.
858 # Thus unique them here and average the gradient here.
859 grad_indices, inverse, cnt = th.unique(
860 idx, return_inverse=True, return_counts=True
861 )
862 state_idx = grad_indices.to(state_dev)
863 state_step[state_idx] += 1
864 state_step = state_step[state_idx].to(exec_dev)
865
866 if use_uva:
867 orig_mem = gather_pinned_tensor_rows(state_mem, grad_indices)
868 orig_power = gather_pinned_tensor_rows(
869 state_power, grad_indices
870 )
871 else:

Callers

nothing calls this directly

Calls 5

_setup_uva · Method · 0.95
to · Method · 0.45
wait · Method · 0.45

Tested by

no test coverage detected