| 444 | class All2All_Wait(Function): |
@staticmethod
def forward(ctx, *output):
    """Finish the pending all-to-all request and return its output as a tuple of 2-D views.

    Reads the request state from the module-level ``myreq`` object:
    calls ``myreq.req.wait()``, then resets ``myreq.req`` and
    ``myreq.tensor`` to ``None``. The ``a2a_info`` record taken from
    ``myreq`` is stored on ``ctx`` as ``ctx.a2a_info``.

    Args:
        ctx: autograd context; receives the ``a2a_info`` attribute.
        *output: positional inputs; only ``output[0]`` is used, and it is
            split and viewed here.
            NOTE(review): presumably the flat receive buffer of the
            all-to-all -- confirm against the matching request function.

    Returns:
        A tuple with one entry per chunk of ``output[0].split(...)``, each
        viewed as ``[a2a_info.local_batch_num, -1]``.
    """
    global myreq
    with record_function("DLRM alltoall_wait_fwd_single"):
        info = myreq.a2a_info
        ctx.a2a_info = info

        # Complete the outstanding request, then drop the references held
        # on the shared request object.
        myreq.req.wait()
        myreq.req = None
        myreq.tensor = None

        # Use the explicit split lengths when they are set (truthy);
        # otherwise fall back to a single split size computed from the
        # table count, batch count and embedding dimension.
        split_sizes = info.table_split_lengths
        if not split_sizes:
            split_sizes = (
                info.local_table_num * info.local_batch_num * info.emb_dim
            )

        chunks = output[0].split(split_sizes)
        rows = info.local_batch_num
        return tuple(chunk.view([rows, -1]) for chunk in chunks)
| 466 | |
| 467 | @staticmethod |
| 468 | def backward(ctx, *grad_outputs): |