hub / github.com/dmlc/dgl / step

Method step

python/dgl/distributed/optim/pytorch/sparse_optim.py:252–428

Invoked at the end of every batch to push the gradients of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.

(self)
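As a rough illustration of what "update the embeddings involved in a mini-batch" means, the sketch below applies a sparse update to an in-memory table: only rows whose indices appear in the batch are changed. It is a minimal sketch, not DGL's implementation. The function name `sparse_sgd_step`, the plain SGD rule, and the choice to sum gradients for repeated indices are all assumptions made for the example, and no servers or inter-process communication are involved.

```python
def sparse_sgd_step(table, indices, grads, lr=0.1):
    """Update only the rows of `table` listed in `indices` (hypothetical helper).

    table:   list of rows, each a list of floats
    indices: row ids touched by the mini-batch (may repeat)
    grads:   one gradient row per entry in `indices`
    """
    # Sum gradients of rows that occur more than once in the batch
    # (an assumption of this sketch, not necessarily DGL's rule).
    summed = {}
    for row, grad in zip(indices, grads):
        if row in summed:
            summed[row] = [a + b for a, b in zip(summed[row], grad)]
        else:
            summed[row] = list(grad)

    # Plain SGD on the touched rows only; all other rows stay untouched.
    for row, grad in summed.items():
        table[row] = [w - lr * g for w, g in zip(table[row], grad)]
    return sorted(summed)


table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
touched = sparse_sgd_step(
    table,
    indices=[0, 2, 0],
    grads=[[1.0, 0.0], [0.0, 2.0], [1.0, 0.0]],
    lr=0.5,
)
```

Row 1 never appears in `indices`, so it keeps its original values. This is the property that makes sparse optimizers cheap for large embedding tables.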

Source from the content-addressed store, hash-verified

def step(self):
    """The step function.

    The step function is invoked at the end of every batch to push the gradients
    of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
    """
    with th.no_grad():
        # [Rui]
        # As `gloo` supports CPU tensors only while `nccl` supports GPU
        # tensors only, we firstly create tensors on the corresponding
        # devices and then copy the data to target device if needed.
        # Please note that the target device can be different from the
        # preferred device.
        target_device = None
        preferred_device = (
            th.device(f"cuda:{self._rank}")
            if th.distributed.get_backend() == "nccl"
            else th.device("cpu")
        )
        local_indics = {emb.name: [] for emb in self._params}
        local_grads = {emb.name: [] for emb in self._params}
        for emb in self._params:
            name = emb.weight.name
            kvstore = emb.weight.kvstore
            trainers_per_server = self._world_size // kvstore.num_servers

            idics = []
            grads = []
            for trace in emb._trace:
                if trace[1].grad is not None:
                    idics.append(trace[0])
                    grads.append(trace[1].grad.data)
                else:
                    assert len(trace[0]) == 0
            # If the sparse embedding is not used in the previous forward step
            # The idx and grad will be empty, initialize them as empty tensors to
            # avoid crashing the optimizer step logic.
            #
            # Note: we cannot skip the gradient exchange and update steps as other
            # working processes may send gradient update requests corresponding
            # to certain embedding to this process.
            #
            # [WARNING][TODO][Rui]
            # For empty idx and grad, we blindly create data on the
            # preferred device, which may not be the device where the
            # embedding is stored.
            idics = (
                th.cat(idics, dim=0)
                if len(idics) != 0
                else th.zeros((0,), dtype=th.int64, device=preferred_device)
            )
            grads = (
                th.cat(grads, dim=0)
                if len(grads) != 0
                else th.zeros(
                    (0, emb.embedding_dim),
                    dtype=th.float32,
                    device=preferred_device,
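The excerpt above stops partway through the method. The trace-gathering and empty-tensor fallback it shows can be tried in isolation with the sketch below. It assumes PyTorch is installed and runs on CPU only. The helper `gather_trace` and its `(index, grad)` pair format are hypothetical simplifications: in the excerpt each trace entry holds a tensor whose `.grad` attribute is read, whereas here the gradient is passed directly.

```python
import torch as th


def gather_trace(trace, embedding_dim, device):
    """Concatenate recorded (indices, grad) pairs, or return empty tensors.

    `trace` is a hypothetical stand-in for `emb._trace`: a list of
    (index_tensor, grad_tensor_or_None) pairs.
    """
    idx_parts, grad_parts = [], []
    for idx, grad in trace:
        if grad is not None:
            idx_parts.append(idx)
            grad_parts.append(grad)
        else:
            # An entry without a gradient must not reference any row.
            assert len(idx) == 0

    # Fall back to correctly shaped empty tensors so later steps that
    # expect (N,) indices and (N, embedding_dim) gradients still work
    # when this process did not use the embedding in the batch.
    idx = (
        th.cat(idx_parts, dim=0)
        if len(idx_parts) != 0
        else th.zeros((0,), dtype=th.int64, device=device)
    )
    grads = (
        th.cat(grad_parts, dim=0)
        if len(grad_parts) != 0
        else th.zeros((0, embedding_dim), dtype=th.float32, device=device)
    )
    return idx, grads


device = th.device("cpu")
used = [
    (th.tensor([3, 7]), th.ones(2, 4)),
    (th.tensor([1]), th.zeros(1, 4)),
]
idx, grads = gather_trace(used, embedding_dim=4, device=device)
empty_idx, empty_grads = gather_trace([], embedding_dim=4, device=device)
```

Keeping the second dimension equal to `embedding_dim` in the empty case matters because, as the source comments note, the gradient exchange cannot be skipped: other processes may still send updates to this one.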

Callers

nothing calls this directly

Calls 10

update · Method · 0.95
alltoall · Function · 0.85
alltoallv · Function · 0.85
append · Method · 0.80
get_partid · Method · 0.80
device · Method · 0.45
long · Method · 0.45
reset_trace · Method · 0.45
to · Method · 0.45
barrier · Method · 0.45

Tested by

no test coverage detected