(ctx, *grad_output)
| 427 | |
@staticmethod
def backward(ctx, *grad_output):
    """Backward pass for the single-tensor all-to-all request.

    Blocks on the pending request held by the module-level ``myreq`` object.
    It then reinterprets the received gradient buffer as a 2-D
    ``[batch_size, -1]`` view and splits it column-wise into chunks of width
    ``emb_dim``. Each chunk is made contiguous before it is returned.

    Args:
        ctx: autograd context. ``ctx.a2a_info`` must provide ``batch_size``
            and ``emb_dim``.
        *grad_output: incoming gradients. They are not read here; the gradient
            data comes from ``myreq.tensor``.

    Returns:
        ``(None, *chunks)``: ``None`` for the first forward input (presumably
        the ``a2a_info`` argument; confirm against ``forward``), followed by
        one contiguous tensor per ``emb_dim``-wide column chunk.
    """
    global myreq
    with record_function("DLRM alltoall_req_bwd_single"):
        info = ctx.a2a_info
        pending = myreq

        # The buffer in ``pending.tensor`` is only safe to read once the
        # outstanding request has completed.
        pending.req.wait()
        pending.req = None

        flat = pending.tensor
        rows = flat.view([info.batch_size, -1])
        chunks = rows.split(info.emb_dim, dim=1)
        grads = tuple(chunk.contiguous() for chunk in chunks)

        # Release the shared reference to the buffer only after the
        # contiguous copies above have been materialized.
        pending.tensor = None
        return (None,) + grads
| 442 | |
| 443 | |
| 444 | class All2All_Wait(Function): |