Update embeddings in a sparse manner. Sparse embeddings are updated in mini-batches. We maintain gradient states for each embedding so they can be updated separately. Parameters ---------- idx : tensor Indices of the embeddings to be updated.
(self, idx, grad, emb)
| 512 | self._state[emb.name] = state |
| 513 | |
| 514 | def update(self, idx, grad, emb): |
| 515 | """Update embeddings in a sparse manner |
| 516 | Sparse embeddings are updated in mini batches. We maintain gradient states for |
| 517 | each embedding so they can be updated separately. |
| 518 | |
| 519 | Parameters |
| 520 | ---------- |
| 521 | idx : tensor |
| 522 | Index of the embeddings to be updated. |
| 523 | grad : tensor |
| 524 | Gradient of each embedding. |
| 525 | emb : dgl.distributed.DistEmbedding |
| 526 | Sparse embedding to update. |
| 527 | """ |
| 528 | eps = self._eps |
| 529 | clr = self._lr |
| 530 | |
| 531 | state_dev = th.device("cpu") |
| 532 | exec_dev = grad.device |
| 533 | |
| 534 | # only perform async copies cpu -> gpu, or gpu-> gpu, but block |
| 535 | # when copying to the cpu, so as to ensure the copy is finished |
| 536 | # before operating on the data on the cpu |
| 537 | state_block = state_dev == th.device("cpu") and exec_dev != state_dev |
| 538 | |
| 539 | # the update is non-linear so indices must be unique |
| 540 | grad_indices, inverse, cnt = th.unique( |
| 541 | idx, return_inverse=True, return_counts=True |
| 542 | ) |
| 543 | grad_values = th.zeros( |
| 544 | (grad_indices.shape[0], grad.shape[1]), device=exec_dev |
| 545 | ) |
| 546 | grad_values.index_add_(0, inverse, grad) |
| 547 | grad_values = grad_values / cnt.unsqueeze(1) |
| 548 | grad_sum = grad_values * grad_values |
| 549 | |
| 550 | # update grad state |
| 551 | grad_state = self._state[emb.name][grad_indices].to(exec_dev) |
| 552 | grad_state += grad_sum |
| 553 | grad_state_dst = grad_state.to(state_dev, non_blocking=True) |
| 554 | if state_block: |
| 555 | # use events to try and overlap CPU and GPU as much as possible |
| 556 | update_event = th.cuda.Event() |
| 557 | update_event.record() |
| 558 | |
| 559 | # update emb |
| 560 | std_values = grad_state.sqrt_().add_(eps) |
| 561 | tmp = clr * grad_values / std_values |
| 562 | tmp_dst = tmp.to(state_dev, non_blocking=True) |
| 563 | |
| 564 | if state_block: |
| 565 | std_event = th.cuda.Event() |
| 566 | std_event.record() |
| 567 | # wait for our transfers from exec_dev to state_dev to finish |
| 568 | # before we can use them |
| 569 | update_event.wait() |
| 570 | self._state[emb.name][grad_indices] = grad_state_dst |
| 571 |