(self)
| 118 | self._comm = True |
| 119 | |
def _shared_setup(self):
    """Register a per-embedding ``<name>_opt_meta`` array in ``self._opt_meta``.

    For every embedding in ``self._params``:

    * rank 0 builds an int32 array of shape ``(world_size, world_size)``
      with ``create_shared_mem_array``, zeroes it in place, and then sets
      the ``<name>_opt_meta`` key on ``emb.store``;
    * ranks greater than 0 wait on that key and then open the array with
      ``get_shared_mem_array`` using the same name, shape and dtype;
    * any other rank value leaves ``self._opt_meta`` untouched.
    """
    for emb in self._params:
        name = emb.name
        rank = self._rank

        if rank == 0:
            # Create and zero the array before setting the store key, so the
            # key is only visible once the array exists.
            meta = create_shared_mem_array(
                name + "_opt_meta",
                (self._world_size, self._world_size),
                th.int32,
            )
            meta = meta.zero_()
            emb.store.set(name + "_opt_meta", name)
            self._opt_meta[name] = meta
        elif rank > 0:
            # Wait until the key has been set, then open the array by name.
            emb.store.wait([name + "_opt_meta"])
            self._opt_meta[name] = get_shared_mem_array(
                name + "_opt_meta",
                (self._world_size, self._world_size),
                th.int32,
            )
| 143 | def _comm_step(self): |
| 144 | with th.no_grad(): |
no test coverage detected