Update embeddings in a sparse manner. Sparse embeddings are updated in mini-batches. We maintain gradient states for each embedding so they can be updated separately. Parameters: idx (tensor): Index of the embeddings to be updated.
(self, idx, grad, emb)
| 812 | emb.set_optm_state(state) |
| 813 | |
| 814 | def update(self, idx, grad, emb): |
| 815 | """Update embeddings in a sparse manner |
| 816 | Sparse embeddings are updated in mini batches. We maintain gradient states for |
| 817 | each embedding so they can be updated separately. |
| 818 | |
| 819 | Parameters |
| 820 | ---------- |
| 821 | idx : tensor |
| 822 | Index of the embeddings to be updated. |
| 823 | grad : tensor |
| 824 | Gradient of each embedding. |
| 825 | emb : dgl.nn.NodeEmbedding |
| 826 | Sparse embedding to update. |
| 827 | """ |
| 828 | with th.no_grad(): |
| 829 | state_step, state_mem, state_power = emb.optm_state |
| 830 | exec_dtype = grad.dtype |
| 831 | exec_dev = grad.device |
| 832 | state_dev = state_step.device |
| 833 | |
| 834 | # whether or not we need to transfer data from the GPU to the CPU |
| 835 | # while updating the weights |
| 836 | is_d2h = state_dev.type == "cpu" and exec_dev.type == "cuda" |
| 837 | |
| 838 | # only perform async copies cpu -> gpu, or gpu-> gpu, but block |
| 839 | # when copying to the cpu, so as to ensure the copy is finished |
| 840 | # before operating on the data on the cpu |
| 841 | state_block = is_d2h |
| 842 | |
| 843 | if self._is_using_uva[emb.name] is None and is_d2h: |
| 844 | # we should use UVA going forward |
| 845 | self._setup_uva(emb.name, state_mem, state_power) |
| 846 | elif self._is_using_uva[emb.name] is None: |
| 847 | # we shouldn't use UVA going forward |
| 848 | self._is_using_uva[emb.name] = False |
| 849 | |
| 850 | use_uva = self._is_using_uva[emb.name] |
| 851 | |
| 852 | beta1 = self._beta1 |
| 853 | beta2 = self._beta2 |
| 854 | eps = self._eps |
| 855 | |
| 856 | clr = self._lr |
| 857 | # There can be duplicated indices due to sampling. |
| 858 | # Thus unique them here and average the gradient here. |
| 859 | grad_indices, inverse, cnt = th.unique( |
| 860 | idx, return_inverse=True, return_counts=True |
| 861 | ) |
| 862 | state_idx = grad_indices.to(state_dev) |
| 863 | state_step[state_idx] += 1 |
| 864 | state_step = state_step[state_idx].to(exec_dev) |
| 865 | |
| 866 | if use_uva: |
| 867 | orig_mem = gather_pinned_tensor_rows(state_mem, grad_indices) |
| 868 | orig_power = gather_pinned_tensor_rows( |
| 869 | state_power, grad_indices |
| 870 | ) |
| 871 | else: |
nothing calls this directly
no test coverage detected