The step function. It is invoked at the end of every mini-batch to push the gradients of the embeddings used in that mini-batch to DGL's servers and then update those embeddings.
(self)
| 250 | ) |
| 251 | |
| 252 | def step(self): |
| 253 | """The step function. |
| 254 | |
| 255 | The step function is invoked at the end of every batch to push the gradients |
| 256 | of the embeddings involved in a mini-batch to DGL's servers and update the embeddings. |
| 257 | """ |
| 258 | with th.no_grad(): |
| 259 | # [Rui] |
| 260 | # As `gloo` supports CPU tensors only while `nccl` supports GPU |
| 261 | # tensors only, we firstly create tensors on the corresponding |
| 262 | # devices and then copy the data to target device if needed. |
| 263 | # Please note that the target device can be different from the |
| 264 | # preferred device. |
| 265 | target_device = None |
| 266 | preferred_device = ( |
| 267 | th.device(f"cuda:{self._rank}") |
| 268 | if th.distributed.get_backend() == "nccl" |
| 269 | else th.device("cpu") |
| 270 | ) |
| 271 | local_indics = {emb.name: [] for emb in self._params} |
| 272 | local_grads = {emb.name: [] for emb in self._params} |
| 273 | for emb in self._params: |
| 274 | name = emb.weight.name |
| 275 | kvstore = emb.weight.kvstore |
| 276 | trainers_per_server = self._world_size // kvstore.num_servers |
| 277 | |
| 278 | idics = [] |
| 279 | grads = [] |
| 280 | for trace in emb._trace: |
| 281 | if trace[1].grad is not None: |
| 282 | idics.append(trace[0]) |
| 283 | grads.append(trace[1].grad.data) |
| 284 | else: |
| 285 | assert len(trace[0]) == 0 |
| 286 | # If the sparse embedding is not used in the previous forward step |
| 287 | # The idx and grad will be empty, initialize them as empty tensors to |
| 288 | # avoid crashing the optimizer step logic. |
| 289 | # |
| 290 | # Note: we cannot skip the gradient exchange and update steps as other |
| 291 | # working processes may send gradient update requests corresponding |
| 292 | # to certain embedding to this process. |
| 293 | # |
| 294 | # [WARNING][TODO][Rui] |
| 295 | # For empty idx and grad, we blindly create data on the |
| 296 | # preferred device, which may not be the device where the |
| 297 | # embedding is stored. |
| 298 | idics = ( |
| 299 | th.cat(idics, dim=0) |
| 300 | if len(idics) != 0 |
| 301 | else th.zeros((0,), dtype=th.int64, device=preferred_device) |
| 302 | ) |
| 303 | grads = ( |
| 304 | th.cat(grads, dim=0) |
| 305 | if len(grads) != 0 |
| 306 | else th.zeros( |
| 307 | (0, emb.embedding_dim), |
| 308 | dtype=th.float32, |
| 309 | device=preferred_device, |
nothing calls this directly
no test coverage detected