(ctx, *output)
| 252 | class All2All_ScatterList_Wait(Function): |
| 253 | @staticmethod |
| 254 | def forward(ctx, *output): |
| 255 | global myreq |
| 256 | ctx.a2a_info = myreq.a2a_info |
| 257 | for r in myreq.req: |
| 258 | r.wait() |
| 259 | myreq.req = None |
| 260 | myreq.tensor = None |
| 261 | return output |
| 262 | |
| 263 | @staticmethod |
| 264 | def backward(ctx, *grad_output): |