(self)
| 141 | self._opt_meta[emb_name] = opt_meta |
| 142 | |
def _comm_step(self):
    """Exchange traced sparse gradients between ranks, then apply the update.

    For every embedding in ``self._params`` the (index, gradient) pairs
    recorded in ``emb._trace`` are merged into one index tensor and one
    gradient tensor and handed to ``nccl.sparse_all_to_all_push`` together
    with the embedding's partition (or a default ``"remainder"`` partition
    when the embedding has none).  After every embedding has been pushed,
    the traces are reset if ``self._clean_grad`` is set, and finally
    ``self.update`` is invoked once per embedding with the received
    indices and gradients.

    The whole step runs inside ``th.no_grad()``.
    """
    with th.no_grad():
        received_idx = {}
        received_grad = {}
        for emb in self._params:  # pylint: disable=too-many-nested-blocks
            name = emb.name

            # Embeddings without their own partition get the default one.
            part = emb.partition
            if not part:
                part = NDArrayPartition(
                    emb.num_embeddings,
                    self._world_size if self._world_size > 0 else 1,
                    mode="remainder",
                )

            # Gradients may come from several forward passes; merge them
            # into a single (idx, grad) pair before communicating.
            num_traces = len(emb._trace)
            if num_traces > 1:
                idx = th.cat([index for index, _ in emb._trace], dim=0)
                grad = th.cat(
                    [traced.grad.data for _, traced in emb._trace], dim=0
                )
            elif num_traces == 1:
                # Single trace: reuse its tensors directly, no concatenation
                # (and hence no extra copy) is needed.
                idx, traced = emb._trace[0]
                grad = traced.grad.data
            else:
                # Nothing was traced: still take part in the exchange, with
                # empty tensors.
                idx = th.zeros((0,), dtype=th.long, device=self._device)
                grad = th.zeros(
                    (0, emb.embedding_dim),
                    dtype=th.float32,
                    device=self._device,
                )

            # NOTE(review): presumably this routes each row to the rank that
            # owns it under `part` -- confirm against the nccl helper.
            pushed_idx, pushed_grad = nccl.sparse_all_to_all_push(
                idx, grad, partition=part
            )
            if emb.partition:
                # A partitioned embedding needs the received indices mapped
                # back to indexes into the local tensor.
                pushed_idx = part.map_to_local(pushed_idx)

            received_idx[name] = pushed_idx
            received_grad[name] = pushed_grad

        if self._clean_grad:
            # Drop the recorded gradient traces of every embedding.
            for emb in self._params:
                emb.reset_trace()
            self._clean_grad = False

        # Apply the optimizer update using what this rank received.
        for emb in self._params:
            name = emb.name
            self.update(received_idx[name], received_grad[name], emb)
no test coverage detected