(self, params)
| 553 | self.setup(self._params) |
| 554 | |
def setup(self, params):
    """Allocate and attach a zero-initialised optimizer state per embedding.

    For every embedding in ``params`` a float32 tensor with the same shape as
    ``emb.weight`` is obtained and registered via ``emb.set_optm_state``.
    Where that tensor lives depends on the embedding's device and on
    ``self._rank``:

    * CPU embedding, ``self._rank < 0``: a private, zero-filled CPU tensor.
    * CPU embedding, ``self._rank == 0``: a shared-memory array named
      ``emb.name + "_state"`` that this rank zero-fills. When
      ``self._world_size > 1``, the key ``emb.name + "_opt"`` is then set in
      ``emb.store`` so the other ranks can attach to it.
    * CPU embedding, ``self._rank > 0``: waits on that key in ``emb.store``.
      It then attaches to the shared-memory array created by rank 0, without
      zero-filling it again.
    * Non-CPU embedding: a zero-filled tensor on ``emb.weight.device``.

    Parameters
    ----------
    params : iterable of NodeEmbedding
        Embeddings whose optimizer state should be created.

    Raises
    ------
    AssertionError
        If an element of ``params`` is not a ``NodeEmbedding``. The check is
        an ``assert``, so it is skipped when Python runs with ``-O``.
    """
    # We need to register a state sum for each embedding in the kvstore.
    for emb in params:
        assert isinstance(
            emb, NodeEmbedding
        ), "SparseAdagrad only supports dgl.nn.NodeEmbedding"

        emb_name = emb.name
        shape = emb.weight.shape
        # Compare the device *type* so an indexed CPU device (e.g. "cpu:0")
        # is still routed to the CPU branch.
        if th.device(emb.weight.device).type == "cpu":
            # If our embedding is on the CPU, our state also has to be.
            if self._rank < 0:
                # Negative rank (presumably non-distributed): plain tensor.
                state = th.zeros(shape, dtype=th.float32, device=th.device("cpu"))
            elif self._rank == 0:
                # Rank 0 creates the shared-memory state and zero-fills it.
                state = create_shared_mem_array(
                    emb_name + "_state", shape, th.float32
                ).zero_()
                if self._world_size > 1:
                    # Publish the key the other ranks wait on below.
                    emb.store.set(emb_name + "_opt", emb_name)
            else:
                # Rank > 0: wait for rank 0's key (presumably blocking), then
                # attach to the same shared-memory array. It is not zero-filled
                # here, to avoid clobbering rank 0's state.
                emb.store.wait([emb_name + "_opt"])
                state = get_shared_mem_array(
                    emb_name + "_state", shape, th.float32
                )
        else:
            # Non-CPU (e.g. GPU) embedding: the state lives on the
            # embedding's own device.
            state = th.zeros(shape, dtype=th.float32, device=emb.weight.device)
        emb.set_optm_state((state,))
| 592 | |
| 593 | def update(self, idx, grad, emb): |
| 594 | """Update embeddings in a sparse manner |
no test coverage detected