(self, params)
| 733 | self._nd_handle[name] = [mem_nd, power_nd] |
| 734 | |
| 735 | def setup(self, params): |
| 736 | # Set up the optimizer state (step, mem and power) for each embedding. |
| 737 | for emb in params: |
| 738 | assert isinstance( |
| 739 | emb, NodeEmbedding |
| 740 | ), "SparseAdam only supports dgl.nn.NodeEmbedding" |
| 741 | emb_name = emb.name |
| 742 | self._is_using_uva[emb_name] = self._use_uva |
| 743 | if th.device(emb.weight.device) == th.device("cpu"): |
| 744 | # if our embedding is on the CPU, our state also has to be |
| 745 | if self._rank < 0: |
| 746 | state_step = th.empty( |
| 747 | (emb.weight.shape[0],), |
| 748 | dtype=th.int32, |
| 749 | device=th.device("cpu"), |
| 750 | ).zero_() |
| 751 | state_mem = th.empty( |
| 752 | emb.weight.shape, |
| 753 | dtype=self._dtype, |
| 754 | device=th.device("cpu"), |
| 755 | ).zero_() |
| 756 | state_power = th.empty( |
| 757 | emb.weight.shape, |
| 758 | dtype=self._dtype, |
| 759 | device=th.device("cpu"), |
| 760 | ).zero_() |
| 761 | elif self._rank == 0: |
| 762 | state_step = create_shared_mem_array( |
| 763 | emb_name + "_step", (emb.weight.shape[0],), th.int32 |
| 764 | ).zero_() |
| 765 | state_mem = create_shared_mem_array( |
| 766 | emb_name + "_mem", emb.weight.shape, self._dtype |
| 767 | ).zero_() |
| 768 | state_power = create_shared_mem_array( |
| 769 | emb_name + "_power", emb.weight.shape, self._dtype |
| 770 | ).zero_() |
| 771 | |
| 772 | if self._world_size > 1: |
| 773 | emb.store.set(emb_name + "_opt", emb_name) |
| 774 | elif self._rank > 0: |
| 775 | # receive |
| 776 | emb.store.wait([emb_name + "_opt"]) |
| 777 | state_step = get_shared_mem_array( |
| 778 | emb_name + "_step", (emb.weight.shape[0],), th.int32 |
| 779 | ) |
| 780 | state_mem = get_shared_mem_array( |
| 781 | emb_name + "_mem", emb.weight.shape, self._dtype |
| 782 | ) |
| 783 | state_power = get_shared_mem_array( |
| 784 | emb_name + "_power", emb.weight.shape, self._dtype |
| 785 | ) |
| 786 | |
| 787 | if self._is_using_uva[emb_name]: |
| 788 | # if use_uva has been explicitly set to true, otherwise |
| 789 | # wait until first step to decide |
| 790 | self._setup_uva(emb_name, state_mem, state_power) |
| 791 | else: |
| 792 | # make sure we don't use UVA when data is on the GPU |
no test coverage detected