MCPcopy
hub / github.com/facebookresearch/dlrm / backward

Method backward

extend_distributed.py:429–441  ·  view source on GitHub ↗
(ctx, *grad_output)

Source from the content-addressed store, hash-verified

427
@staticmethod
def backward(ctx, *grad_output):
    """Complete the pending all-to-all and return gradients for the forward inputs.

    The incoming ``grad_output`` values are not used. Instead, this method
    waits on the request handle stored in the module-global ``myreq.req``. It
    then reads the tensor stored in ``myreq.tensor`` and reshapes it with
    ``view([batch_size, -1])``. Finally, it splits that view along dim 1 into
    chunks of ``a2a_info.emb_dim`` columns.

    NOTE(review): presumably ``myreq.req`` / ``myreq.tensor`` were populated by
    the matching backward of the companion wait Function before autograd
    reaches this node. That ordering is not visible here; confirm against
    the rest of the file.

    Args:
        ctx: autograd context; only ``ctx.a2a_info`` (with ``batch_size`` and
            ``emb_dim`` attributes) is read.
        *grad_output: ignored.

    Returns:
        A tuple whose first entry is ``None`` (no gradient for the first
        forward argument), followed by one contiguous chunk per split.
    """
    global myreq
    with record_function("DLRM alltoall_req_bwd_single"):
        info = ctx.a2a_info
        # The collective was launched asynchronously; block until it finishes.
        myreq.req.wait()
        myreq.req = None
        received = myreq.tensor
        pieces = received.view([info.batch_size, -1]).split(info.emb_dim, dim=1)
        # split() yields strided views; materialize each one contiguously.
        contiguous_pieces = tuple(piece.contiguous() for piece in pieces)
        # Drop the shared reference so the buffer can be freed.
        myreq.tensor = None
        return (None,) + contiguous_pieces
442
443
444class All2All_Wait(Function):

Callers

nothing calls this directly

Calls 1

wait · Method · 0.80

Tested by

no test coverage detected