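Supports forward offloading; when `supports_backward` is enabled, it also registers full backward hooks so blocks can be offloaded during the backward pass.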
| 168 | |
| 169 | |
| 170 | class ModelOffloader(Offloader): |
| 171 | """ |
| 172 | supports forward offloading |
| 173 | """ |
| 174 | |
| 175 | def __init__( |
| 176 | self, |
| 177 | blocks: Union[list[nn.Module], nn.ModuleList], |
| 178 | blocks_to_swap: int, |
| 179 | device: torch.device, |
| 180 | supports_backward: bool = True, |
| 181 | debug: bool = False, |
| 182 | ): |
| 183 | super().__init__(len(blocks), blocks_to_swap, device, debug) |
| 184 | |
| 185 | self.supports_backward = supports_backward |
| 186 | self.forward_only = not supports_backward # forward only offloading: can be changed to True for inference |
| 187 | |
| 188 | if self.supports_backward: |
| 189 | # register backward hooks |
| 190 | self.remove_handles = [] |
| 191 | for i, block in enumerate(blocks): |
| 192 | hook = self.create_backward_hook(blocks, i) |
| 193 | if hook is not None: |
| 194 | handle = block.register_full_backward_hook(hook) |
| 195 | self.remove_handles.append(handle) |
| 196 | |
| 197 | def set_forward_only(self, forward_only: bool): |
| 198 | # switching must wait for all pending transfers |
| 199 | for block_idx in list(self.futures.keys()): |
| 200 | self._wait_blocks_move(block_idx) |
| 201 | self.forward_only = forward_only |
| 202 | |
| 203 | def __del__(self): |
| 204 | if self.supports_backward: |
| 205 | for handle in self.remove_handles: |
| 206 | handle.remove() |
| 207 | |
| 208 | def create_backward_hook( |
| 209 | self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int |
| 210 | ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]: |
| 211 | # -1 for 0-based index |
| 212 | num_blocks_propagated = self.num_blocks - block_index - 1 |
| 213 | swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap |
| 214 | waiting = block_index > 0 and block_index <= self.blocks_to_swap |
| 215 | |
| 216 | if not swapping and not waiting: |
| 217 | return None |
| 218 | |
| 219 | # create hook |
| 220 | block_idx_to_cpu = self.num_blocks - num_blocks_propagated |
| 221 | block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated |
| 222 | block_idx_to_wait = block_index - 1 |
| 223 | |
| 224 | def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t): |
| 225 | if self.debug: |
| 226 | print(f"Backward hook for block {block_index}") |
| 227 |
no outgoing calls