(ctx, *grad_output)
| 240 | |
@staticmethod
def backward(ctx, *grad_output):
    """Finish the pending collective and return its results as gradients.

    Calls ``wait()`` on every request handle stored in ``myreq.req``. It then
    returns the tensors held in ``myreq.tensor`` and resets both fields to
    ``None``, so ``myreq`` no longer references them.

    Args:
        ctx: Autograd context (not used by this method).
        *grad_output: Incoming gradients (not used by this method). Presumably
            the matching communication was launched elsewhere and its
            results stored on the module-level ``myreq``. TODO: confirm
            against the caller / ``forward``.

    Returns:
        The tuple ``(None, *myreq.tensor)``. The leading ``None`` is
        presumably the gradient slot for ``forward``'s first,
        non-differentiable argument. Verify against ``forward``'s signature.
    """
    # ``myreq`` is a module-level object that is only read here. Its
    # attributes are mutated but the name is never rebound, so no
    # ``global`` declaration is required.
    for req in myreq.req:
        req.wait()
    myreq.req = None
    grad_inputs = myreq.tensor
    # Drop the module-level reference so the tensors can be released once
    # autograd is done with the returned gradients.
    myreq.tensor = None
    return (None, *grad_inputs)
| 250 | |
| 251 | |
| 252 | class All2All_ScatterList_Wait(Function): |