MCPcopy
hub / github.com/dmlc/dgl / setup

Method setup

python/dgl/optim/pytorch/sparse_optim.py:735–812  ·  view source on GitHub ↗
(self, params)

Source from the content-addressed store, hash-verified

733 self._nd_handle[name] = [mem_nd, power_nd]
734
735 def setup(self, params):
736 # We need to register a state sum for each embedding in the kvstore.
737 for emb in params:
738 assert isinstance(
739 emb, NodeEmbedding
740 ), "SparseAdam only supports dgl.nn.NodeEmbedding"
741 emb_name = emb.name
742 self._is_using_uva[emb_name] = self._use_uva
743 if th.device(emb.weight.device) == th.device("cpu"):
744 # if our embedding is on the CPU, our state also has to be
745 if self._rank < 0:
746 state_step = th.empty(
747 (emb.weight.shape[0],),
748 dtype=th.int32,
749 device=th.device("cpu"),
750 ).zero_()
751 state_mem = th.empty(
752 emb.weight.shape,
753 dtype=self._dtype,
754 device=th.device("cpu"),
755 ).zero_()
756 state_power = th.empty(
757 emb.weight.shape,
758 dtype=self._dtype,
759 device=th.device("cpu"),
760 ).zero_()
761 elif self._rank == 0:
762 state_step = create_shared_mem_array(
763 emb_name + "_step", (emb.weight.shape[0],), th.int32
764 ).zero_()
765 state_mem = create_shared_mem_array(
766 emb_name + "_mem", emb.weight.shape, self._dtype
767 ).zero_()
768 state_power = create_shared_mem_array(
769 emb_name + "_power", emb.weight.shape, self._dtype
770 ).zero_()
771
772 if self._world_size > 1:
773 emb.store.set(emb_name + "_opt", emb_name)
774 elif self._rank > 0:
775 # receive
776 emb.store.wait([emb_name + "_opt"])
777 state_step = get_shared_mem_array(
778 emb_name + "_step", (emb.weight.shape[0],), th.int32
779 )
780 state_mem = get_shared_mem_array(
781 emb_name + "_mem", emb.weight.shape, self._dtype
782 )
783 state_power = get_shared_mem_array(
784 emb_name + "_power", emb.weight.shape, self._dtype
785 )
786
787 if self._is_using_uva[emb_name]:
788 # if use_uva has been explicitly set to true, otherwise
789 # wait until first step to decide
790 self._setup_uva(emb_name, state_mem, state_power)
791 else:
792 # make sure we don't use UVA when data is on the GPU

Callers 1

__init__ · Method · 0.95

Calls 6

_setup_uva · Method · 0.95
create_shared_mem_array · Function · 0.85
get_shared_mem_array · Function · 0.85
set_optm_state · Method · 0.80
device · Method · 0.45
wait · Method · 0.45

Tested by

no test coverage detected