(ctx, *grad_output)
| 331 | |
@staticmethod
def backward(ctx, *grad_output):
    """Wait for the pending request(s) held in ``myreq`` and return the split gradient.

    ``grad_output`` is not read here. The gradient is taken from the
    module-level ``myreq.tensor`` instead. NOTE(review): presumably that
    buffer is filled by an asynchronous exchange started elsewhere; confirm
    against the matching forward/wait Function.

    Args:
        ctx: autograd context; only ``ctx.a2a_info.emb_dim`` is used, passed
            as the ``split_size_or_sections`` argument when splitting along
            dim 1.
        *grad_output: incoming gradients (ignored).

    Returns:
        A tuple ``(None, piece_0, piece_1, ...)``. It holds ``None`` for the
        first forward input, followed by the pieces of ``myreq.tensor`` split
        along dim 1.
    """
    # ``myreq`` is only mutated through its attributes and never rebound, so
    # a plain global read suffices and no ``global`` statement is needed.
    for handle in myreq.req:
        handle.wait()
    # Only after every handle has completed is the buffer read; the holder's
    # references are cleared so it no longer points at this step's state.
    myreq.req = None
    pieces = myreq.tensor.split(ctx.a2a_info.emb_dim, dim=1)
    myreq.tensor = None
    return (None,) + tuple(pieces)
| 342 | |
| 343 | |
| 344 | class All2All_Scatter_Wait(Function): |