Repository: github.com/facebookresearch/dlrm

Method backward

extend_distributed.py:333–341
(ctx, *grad_output)


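    @staticmethod
    def backward(ctx, *grad_output):
        # grad_output is not read: the incoming gradient is taken from the
        # module-level `myreq` state, which holds the pending requests and
        # the buffer they fill.
        global myreq
        # Block until every outstanding communication request has completed.
        for r in myreq.req:
            r.wait()
        myreq.req = None
        # Split the received gradient along dim 1 into chunks of width
        # emb_dim, one per embedding input of the forward pass.
        grad_input = myreq.tensor
        grad_inputs = grad_input.split(ctx.a2a_info.emb_dim, dim=1)
        # Release the buffer reference now that it has been consumed.
        myreq.tensor = None
        # None for the first forward argument, then one gradient per chunk.
        return (None, *grad_inputs)


class All2All_Scatter_Wait(Function):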
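The shape bookkeeping in this method can be seen in isolation with a minimal, single-process sketch, assuming PyTorch is installed. `ConcatTables` is a hypothetical stand-in, not a DLRM class: its forward concatenates per-table embeddings along dim 1 (the layout implied by splitting the gradient by `emb_dim` on dim 1), and its backward undoes that with `split`, returning `None` for the non-tensor first argument. There is no communication here; in the original, the gradient comes from `myreq.tensor` after the pending requests finish, not from `grad_output`.

```python
import torch
from torch.autograd import Function


class ConcatTables(Function):
    """Hypothetical stand-in: concatenate per-table embeddings in forward,
    split the gradient back per table in backward."""

    @staticmethod
    def forward(ctx, emb_dim, *inputs):
        # emb_dim is a plain int, so autograd expects None as its gradient.
        ctx.emb_dim = emb_dim
        # Each input is (batch, emb_dim); the result is (batch, n_tables * emb_dim).
        return torch.cat(inputs, dim=1)

    @staticmethod
    def backward(ctx, grad_output):
        # Split the (batch, n_tables * emb_dim) gradient into n_tables chunks
        # of width emb_dim, one per forward input.
        grad_inputs = grad_output.split(ctx.emb_dim, dim=1)
        # One gradient per forward argument: None for emb_dim, then the chunks.
        return (None, *grad_inputs)


emb_dim, batch = 4, 3
tables = [torch.randn(batch, emb_dim, requires_grad=True) for _ in range(3)]
out = ConcatTables.apply(emb_dim, *tables)
# Weight each column differently so every table receives a distinct gradient.
weights = torch.arange(out.shape[1], dtype=out.dtype)
(out * weights).sum().backward()
for i, t in enumerate(tables):
    print(i, t.grad[0].tolist())  # table 0 prints [0.0, 1.0, 2.0, 3.0]
```

Returning `None` first mirrors the `(None, *grad_inputs)` convention above: autograd requires one entry per forward argument, and non-tensor arguments get `None`.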

Callers

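nothing calls this directly (as the backward of an autograd Function, it is invoked by the autograd engine during the backward pass)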

Calls 1

wait · Method · 0.80
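The one recorded call, `wait`, is made on each handle in `myreq.req` before the buffer is read. That "wait on every request, drop the handles, consume the buffer, drop the buffer" sequence can be sketched with the standard library alone. This is a plain-Python simulation under stated assumptions: threads stand in for asynchronous communication handles, a flat list stands in for the 2-D gradient tensor, and `PendingRequest`, `RequestState`, `post_requests` and `collect` are all hypothetical names, not DLRM or PyTorch APIs.

```python
import threading


class PendingRequest:
    """Hypothetical stand-in for an async communication handle with wait()."""

    def __init__(self, fn):
        self._thread = threading.Thread(target=fn)
        self._thread.start()

    def wait(self):
        # Block until the background work has finished.
        self._thread.join()


class RequestState:
    """Hypothetical analogue of the module-level `myreq` holder."""

    def __init__(self):
        self.req = None
        self.tensor = None


state = RequestState()


def post_requests(n_chunks, width):
    # Launch one "request" per chunk; each fills its own slice of the buffer.
    buffer = [0] * (n_chunks * width)

    def fill(k):
        def run():
            for j in range(k * width, (k + 1) * width):
                buffer[j] = k
        return run

    state.req = [PendingRequest(fill(k)) for k in range(n_chunks)]
    state.tensor = buffer


def collect(width):
    # Wait on every request, drop the handles, split the received buffer
    # into width-sized chunks, then drop the buffer reference.
    for r in state.req:
        r.wait()
    state.req = None
    buffer = state.tensor
    chunks = [buffer[i:i + width] for i in range(0, len(buffer), width)]
    state.tensor = None
    return chunks


post_requests(n_chunks=3, width=2)
print(collect(width=2))  # [[0, 0], [1, 1], [2, 2]]
```

Setting both fields to `None` after use lets the handles and buffer be freed before the next iteration posts new requests into the same shared holder.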

Tested by

no test coverage detected