MCPcopy
hub / github.com/kohya-ss/sd-scripts / ModelOffloader

Class ModelOffloader

library/custom_offloading_utils.py:170–276  ·  view source on GitHub ↗

supports forward offloading

Source from the content-addressed store, hash-verified

168
169
170class ModelOffloader(Offloader):
171 """
172 supports forward offloading
173 """
174
def __init__(
    self,
    blocks: Union[list[nn.Module], nn.ModuleList],
    blocks_to_swap: int,
    device: torch.device,
    supports_backward: bool = True,
    debug: bool = False,
):
    """Initialize the offloader and optionally attach backward hooks.

    The base initializer receives ``len(blocks)``, ``blocks_to_swap``,
    ``device`` and ``debug`` unchanged.

    Args:
        blocks: The modules to manage; only their count and, when
            ``supports_backward`` is true, their hook registration are
            used here.
        blocks_to_swap: Forwarded to the base initializer.
        device: Forwarded to the base initializer.
        supports_backward: When true, a full backward hook is registered
            on every block for which ``create_backward_hook`` returns one.
        debug: Forwarded to the base initializer.
    """
    super().__init__(len(blocks), blocks_to_swap, device, debug)

    self.supports_backward = supports_backward
    # forward only offloading: can be changed to True for inference
    self.forward_only = not supports_backward

    if not self.supports_backward:
        # No hooks needed; note that ``remove_handles`` is not created here.
        return

    # Keep each registration handle so the hooks can be detached later.
    self.remove_handles = []
    for index, module in enumerate(blocks):
        backward_hook = self.create_backward_hook(blocks, index)
        if backward_hook is None:
            # This block takes no part in backward-time swapping/waiting.
            continue
        self.remove_handles.append(module.register_full_backward_hook(backward_hook))
196
def set_forward_only(self, forward_only: bool):
    """Wait for all pending block transfers, then set ``forward_only``.

    Args:
        forward_only: The new value stored on ``self.forward_only``.
    """
    # Iterate over a snapshot of the keys; presumably waiting on a block
    # removes its entry from ``self.futures`` -- confirm in the base class.
    pending_indices = tuple(self.futures)
    for pending_idx in pending_indices:
        self._wait_blocks_move(pending_idx)

    # Only flip the mode once nothing is in flight.
    self.forward_only = forward_only
202
def __del__(self):
    """Detach any backward hooks that were registered by ``__init__``.

    The finalizer also runs for instances whose ``__init__`` raised before
    ``supports_backward`` or ``remove_handles`` was assigned, so both are
    read with defaults instead of assuming they exist. Without this, a
    failed construction would additionally emit an "Exception ignored in
    __del__" AttributeError.
    """
    if not getattr(self, "supports_backward", False):
        return
    for handle in getattr(self, "remove_handles", ()):
        handle.remove()
207
208 def create_backward_hook(
209 self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int
210 ) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
211 # -1 for 0-based index
212 num_blocks_propagated = self.num_blocks - block_index - 1
213 swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
214 waiting = block_index > 0 and block_index <= self.blocks_to_swap
215
216 if not swapping and not waiting:
217 return None
218
219 # create hook
220 block_idx_to_cpu = self.num_blocks - num_blocks_propagated
221 block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
222 block_idx_to_wait = block_index - 1
223
224 def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
225 if self.debug:
226 print(f"Backward hook for block {block_index}")
227

Callers 4

model_offloaderFunction · 0.90

Calls

no outgoing calls

Tested by 4

model_offloaderFunction · 0.72