Update embeddings in a sparse manner. Sparse embeddings are updated in mini batches. We maintain gradient states for each embedding so they can be updated separately. Parameters ---------- idx : tensor Index of the embeddings to be updated.
(self, idx, grad, emb)
| 656 | self._state[emb.name] = state |
| 657 | |
| 658 | def update(self, idx, grad, emb): |
| 659 | """Update embeddings in a sparse manner |
| 660 | Sparse embeddings are updated in mini batches. We maintain gradient states for |
| 661 | each embedding so they can be updated separately. |
| 662 | |
| 663 | Parameters |
| 664 | ---------- |
| 665 | idx : tensor |
| 666 | Index of the embeddings to be updated. |
| 667 | grad : tensor |
| 668 | Gradient of each embedding. |
| 669 | emb : dgl.distributed.DistEmbedding |
| 670 | Sparse embedding to update. |
| 671 | """ |
| 672 | beta1 = self._beta1 |
| 673 | beta2 = self._beta2 |
| 674 | eps = self._eps |
| 675 | clr = self._lr |
| 676 | state_step, state_mem, state_power = self._state[emb.name] |
| 677 | |
| 678 | state_dev = th.device("cpu") |
| 679 | exec_dev = grad.device |
| 680 | |
| 681 | # only perform async copies cpu -> gpu, or gpu-> gpu, but block |
| 682 | # when copying to the cpu, so as to ensure the copy is finished |
| 683 | # before operating on the data on the cpu |
| 684 | state_block = state_dev == th.device("cpu") and exec_dev != state_dev |
| 685 | |
| 686 | # the update is non-linear so indices must be unique |
| 687 | grad_indices, inverse, cnt = th.unique( |
| 688 | idx, return_inverse=True, return_counts=True |
| 689 | ) |
| 690 | # update grad state |
| 691 | state_idx = grad_indices.to(state_dev) |
| 692 | # The original implementation will cause read/write contention. |
| 693 | # state_step[state_idx] += 1 |
| 694 | # state_step = state_step[state_idx].to(exec_dev, non_blocking=True) |
| 695 | # In a distributed environment, the first line of code will send write requests to |
| 696 | # kvstore servers to update the state_step which is asynchronous and the second line |
| 697 | # of code will also send read requests to kvstore servers. The write and read requests |
| 698 | # may be handled by different kvstore servers managing the same portion of the |
| 699 | # state_step dist tensor in the same node. So that, the read request may read an old |
| 700 | # value (i.e., 0 in the first iteration) which will cause |
| 701 | # update_power_corr to be NaN |
| 702 | state_val = state_step[state_idx] + 1 |
| 703 | state_step[state_idx] = state_val |
| 704 | state_step = state_val.to(exec_dev) |
| 705 | orig_mem = state_mem[state_idx].to(exec_dev) |
| 706 | orig_power = state_power[state_idx].to(exec_dev) |
| 707 | |
| 708 | grad_values = th.zeros( |
| 709 | (grad_indices.shape[0], grad.shape[1]), device=exec_dev |
| 710 | ) |
| 711 | grad_values.index_add_(0, inverse, grad) |
| 712 | grad_values = grad_values / cnt.unsqueeze(1) |
| 713 | grad_mem = grad_values |
| 714 | grad_power = grad_values * grad_values |
| 715 | update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem |
no test coverage detected