| 633 | |
@staticmethod
def allocate_workspace(mapping: Mapping,
                       size: int) -> Tuple[List[IpcMemory], "torch.Tensor"]:
    """Allocate the custom all-reduce workspace and pack its pointers.

    Seven ``IpcMemory`` regions are created, in this order: the ping and
    pong data buffers, the inbound and outbound barrier buffers, and three
    lamport buffers. A three-element int64 CUDA tensor of flags is created
    after them.

    Args:
        mapping: Parallel mapping. ``mapping.tp_size`` scales the buffer
            sizes, and the mapping itself is forwarded to
            ``can_access_peer`` and to every ``IpcMemory``.
        size: Base size of one data buffer, in whatever unit ``IpcMemory``
            expects (presumably bytes -- TODO confirm).

    Returns:
        A pair ``(buffers, pointers)``.

        ``buffers`` holds the seven ``IpcMemory`` objects followed by the
        flag tensor. NOTE(review): the last element is therefore a tensor,
        not an ``IpcMemory``, which the ``List[IpcMemory]`` annotation
        does not express.

        ``pointers`` is an int64 CPU tensor built from each region's
        ``serialize()`` output, in allocation order, followed by the
        addresses of the three flag elements.
    """
    import torch

    # Force pull mode and disable lamport when force deterministic is
    # enabled, for reducing device memory usage.
    force_deterministic = force_all_reduce_deterministic()
    is_p2p_supported = can_access_peer(mapping)
    tp_size = mapping.tp_size

    ipc_buffers_size = size if force_deterministic else size * tp_size
    ipc_barriers_size = (IpcMemory.IPC_BARRIERS_SIZE_PER_GPU * tp_size * 2 *
                         tp_size)
    lamport_buffers_size = 1 if force_deterministic else size * tp_size

    # One entry per region; the order here is the allocation order and also
    # the layout of the returned pointer tensor.
    region_sizes = (
        ipc_buffers_size,  # ping
        ipc_buffers_size,  # pong
        ipc_barriers_size,  # barriers in
        ipc_barriers_size,  # barriers out
        lamport_buffers_size,  # lamport 0
        lamport_buffers_size,  # lamport 1
        lamport_buffers_size,  # lamport 2
    )
    ipc_regions = [
        IpcMemory(mapping, region_size, is_p2p_supported)
        for region_size in region_sizes
    ]
    lamport_buffers_0, lamport_buffers_1, lamport_buffers_2 = ipc_regions[4:]

    # TODO: it seems we may need to initialize lamport buffers for all tp
    # groups just like its cpp counterpart
    # (AllReduceBuffers::AllReduceBuffers()) does.
    if is_p2p_supported:
        lamport_initialize_all(
            lamport_buffers_0.local_ptr,
            lamport_buffers_1.local_ptr,
            lamport_buffers_2.local_ptr,
            lamport_buffers_size,
        )

    # Start from 1 since 0 represents released state for barrier at the
    # beginning of the all_reduce.
    # The last element is the barrier flag counter.
    flags = torch.tensor([1, 1, 0], dtype=torch.int64, device="cuda")
    buffers = [*ipc_regions, flags]

    pointer_values = []
    for region in ipc_regions:
        pointer_values.extend(region.serialize())
    # Address of each flag element; flags[0:] aliases the tensor start.
    pointer_values.extend(flags[offset:].data_ptr() for offset in range(3))

    return buffers, torch.tensor(pointer_values,
                                 dtype=torch.int64,
                                 device="cpu")
| 690 | |
| 691 | @staticmethod |
| 692 | def allocate_lowprecision_workspace( |