common offloading class
| 107 | |
| 108 | |
class Offloader:
    """Common offloading class.

    Swaps the weights of pairs of blocks between the CPU and ``device`` on a
    background executor with a single worker. Submitted swaps therefore run
    one at a time, in submission order, while the caller keeps working until
    it waits on a specific block.

    Attributes:
        num_blocks: Stored as given; not read by the methods of this class.
        blocks_to_swap: Stored as given; not read by the methods of this class.
        device: Device handed to the swap helper functions.
        debug: When true, progress and timing messages are printed.
        thread_pool: Executor with ``max_workers=1`` that runs the swaps.
        futures: Pending swap futures, keyed by the index of the block being
            moved onto ``device``.
        cuda_available: ``device.type == "cuda"``; selects which swap helper
            is used.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        self.num_blocks, self.blocks_to_swap = num_blocks, blocks_to_swap
        self.device, self.debug = device, debug

        # One worker only: swaps are serialised and complete in submission order.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        """Swap device placement of two blocks by delegating to a helper.

        Calls ``swap_weight_devices_cuda`` when ``self.device`` is a CUDA
        device, otherwise ``swap_weight_devices_no_cuda``. Both receive
        ``(self.device, block_to_cpu, block_to_cuda)``. Presumably the helpers
        move ``block_to_cpu`` off the device and ``block_to_cuda`` onto it, as
        their names suggest; confirm against their definitions.
        """
        # The conditional expression only resolves the helper name that is used.
        swap_fn = swap_weight_devices_cuda if self.cuda_available else swap_weight_devices_no_cuda
        swap_fn(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        """Queue a background swap of ``blocks[block_idx_to_cpu]`` and ``blocks[block_idx_to_cuda]``.

        The resulting future is stored in ``self.futures`` under
        ``block_idx_to_cuda``, replacing any entry already there. Use
        ``_wait_blocks_move(block_idx_to_cuda)`` to block until it finishes.
        """

        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            # Executes on the worker thread. The returned indices let the
            # waiter verify it collected the swap it expected.
            if self.debug:
                start_time = time.perf_counter()
                target = "CUDA" if self.cuda_available else "device"
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {target}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if self.debug:
                elapsed = time.perf_counter() - start_time
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {elapsed:.2f}s")
            return bidx_to_cpu, bidx_to_cuda

        # Resolve both blocks on the calling thread before handing off the work.
        leaving = blocks[block_idx_to_cpu]
        arriving = blocks[block_idx_to_cuda]
        pending = self.thread_pool.submit(move_blocks, block_idx_to_cpu, leaving, block_idx_to_cuda, arriving)
        self.futures[block_idx_to_cuda] = pending

    def _wait_blocks_move(self, block_idx):
        """Block until the queued swap that moves ``block_idx`` onto the device completes.

        Returns at once if no swap is pending for ``block_idx``. The future is
        removed from ``self.futures`` before waiting. Any exception raised by
        the swap propagates from ``Future.result()``.

        Raises:
            AssertionError: if the completed swap reports a different
                destination index than ``block_idx``.
        """
        if block_idx not in self.futures:
            return

        if self.debug:
            print(f"Wait for block {block_idx}")
            start_time = time.perf_counter()

        _, arrived_idx = self.futures.pop(block_idx).result()

        # Sanity check: the future stored under block_idx must have moved that block.
        assert block_idx == arrived_idx, f"Block index mismatch: {block_idx} != {arrived_idx}"

        if self.debug:
            waited = time.perf_counter() - start_time
            print(f"Waited for block {block_idx}: {waited:.2f}s")
| 164 | |
| 165 | |
| 166 | # Gradient tensors |