MCPcopy
hub / github.com/dmlc/dgl / setup

Method setup

python/dgl/optim/pytorch/sparse_optim.py:555–591  ·  view source on GitHub ↗
(self, params)

Source from the content-addressed store, hash-verified

553 self.setup(self._params)
554
def setup(self, params):
    """Allocate and attach the optimizer state tensor for every embedding.

    For each embedding in ``params`` a float32 tensor of zeros with the
    same shape as ``emb.weight`` is obtained and attached through
    ``emb.set_optm_state((state,))``. Where the tensor comes from depends
    on the embedding's device and on ``self._rank``:

    * CPU embedding, ``self._rank < 0``: a private CPU tensor.
    * CPU embedding, ``self._rank == 0``: a shared-memory array named
      ``emb.name + "_state"`` is created and zeroed. When
      ``self._world_size > 1``, the key ``emb.name + "_opt"`` is then set
      in ``emb.store`` to tell the other ranks the array exists.
    * CPU embedding, ``self._rank > 0``: waits on ``emb.name + "_opt"`` in
      ``emb.store``, then attaches to the shared-memory array rank 0
      created (it is not zeroed again here).
    * Non-CPU embedding: a tensor allocated on ``emb.weight.device``.

    Parameters
    ----------
    params : iterable of dgl.nn.NodeEmbedding
        The embeddings this optimizer updates.

    Raises
    ------
    AssertionError
        If an element of ``params`` is not a ``NodeEmbedding``.
    """
    for emb in params:
        assert isinstance(
            emb, NodeEmbedding
        ), "SparseAdagrad only supports dgl.nn.NodeEmbedding"

        emb_name = emb.name
        shape = emb.weight.shape
        if th.device(emb.weight.device) == th.device("cpu"):
            # If the embedding lives on the CPU, its state must too. With
            # multiple ranks it is placed in shared memory so every rank
            # sees the same tensor.
            shm_name = emb_name + "_state"
            if self._rank < 0:
                # No rank assigned: a private, process-local tensor suffices.
                state = th.zeros(shape, dtype=th.float32, device=th.device("cpu"))
            elif self._rank == 0:
                # Rank 0 creates and zeroes the shared array; get_shared_mem_array
                # on other ranks does not zero it, so this must happen here.
                state = create_shared_mem_array(shm_name, shape, th.float32).zero_()
                if self._world_size > 1:
                    # Publish a key so the other ranks know the array exists.
                    emb.store.set(emb_name + "_opt", emb_name)
            else:
                # rank > 0: block until rank 0 has published the key, then
                # attach to the same shared-memory array.
                emb.store.wait([emb_name + "_opt"])
                state = get_shared_mem_array(shm_name, shape, th.float32)
        else:
            # Distributed state on the embedding's (GPU) device.
            state = th.zeros(shape, dtype=th.float32, device=emb.weight.device)
        emb.set_optm_state((state,))
592
593 def update(self, idx, grad, emb):
594 """Update embeddings in a sparse manner

Callers 1

__init__ — Method · 0.95

Calls 5

create_shared_mem_array — Function · 0.85
get_shared_mem_array — Function · 0.85
set_optm_state — Method · 0.80
device — Method · 0.45
wait — Method · 0.45

Tested by

no test coverage detected