MCPcopy
hub / github.com/kohya-ss/sd-scripts / Offloader

Class Offloader

library/custom_offloading_utils.py:109–163  ·  view source on GitHub ↗

common offloading class

Source from the content-addressed store, hash-verified

107
108
class Offloader:
    """
    Common offloading class.

    Queues weight swaps between pairs of blocks on a single background worker
    thread and lets the caller block until the swap for a given block is done.
    """

    def __init__(self, num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
        """
        Args:
            num_blocks: stored on the instance; not read by the methods of this class.
            blocks_to_swap: stored on the instance; not read by the methods of this class.
            device: target device; ``device.type == "cuda"`` selects the CUDA swap helper.
            debug: when True, progress and timing messages are printed.
        """
        self.num_blocks = num_blocks
        self.blocks_to_swap = blocks_to_swap
        self.device = device
        self.debug = debug

        # A single worker, so queued swaps never run concurrently with each other.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)
        # Pending swaps, keyed by the index of the block being brought onto the device.
        self.futures = {}
        self.cuda_available = device.type == "cuda"

    def swap_weight_devices(self, block_to_cpu: nn.Module, block_to_cuda: nn.Module):
        """
        Swap the weights of two blocks between devices.

        Dispatches to ``swap_weight_devices_cuda`` when the target device is CUDA
        and to ``swap_weight_devices_no_cuda`` otherwise.
        NOTE(review): both helpers are defined outside this class; presumably they
        move ``block_to_cpu`` off ``self.device`` and ``block_to_cuda`` onto it.
        """
        if self.cuda_available:
            swap_weight_devices_cuda(self.device, block_to_cpu, block_to_cuda)
        else:
            swap_weight_devices_no_cuda(self.device, block_to_cpu, block_to_cuda)

    def _submit_move_blocks(self, blocks, block_idx_to_cpu, block_idx_to_cuda):
        """
        Queue a swap of ``blocks[block_idx_to_cpu]`` and ``blocks[block_idx_to_cuda]``.

        The resulting future is stored in ``self.futures`` under ``block_idx_to_cuda``
        so that ``_wait_blocks_move(block_idx_to_cuda)`` can wait for it.

        Args:
            blocks: indexable collection of blocks.
            block_idx_to_cpu: index of the block passed as the "to CPU" side of the swap.
            block_idx_to_cuda: index of the block passed as the "to device" side of the swap.
        """

        def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
            # Runs on the worker thread. Read the flag once and always take the
            # start time: re-reading self.debug after the swap could otherwise
            # reach the timing print with start_time never assigned.
            debug = self.debug
            start_time = time.perf_counter()
            if debug:
                print(f"Move block {bidx_to_cpu} to CPU and block {bidx_to_cuda} to {'CUDA' if self.cuda_available else 'device'}")

            self.swap_weight_devices(block_to_cpu, block_to_cuda)

            if debug:
                print(f"Moved blocks {bidx_to_cpu} and {bidx_to_cuda} in {time.perf_counter() - start_time:.2f}s")
            return bidx_to_cpu, bidx_to_cuda

        block_to_cpu = blocks[block_idx_to_cpu]
        block_to_cuda = blocks[block_idx_to_cuda]

        self.futures[block_idx_to_cuda] = self.thread_pool.submit(
            move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda
        )

    def _wait_blocks_move(self, block_idx):
        """
        Block until the queued swap that brings ``block_idx`` onto the device finishes.

        Returns immediately when no swap is pending for ``block_idx``. The pending
        entry is removed from ``self.futures`` before waiting.

        Raises:
            AssertionError: if the finished swap reports a different target index.
            Exception: anything raised by the swap on the worker thread is re-raised
                here by ``Future.result()``.
        """
        # Single lookup-and-remove instead of a membership test followed by pop.
        future = self.futures.pop(block_idx, None)
        if future is None:
            return

        # Read the flag once so the closing timing print cannot be reached
        # without the opening one having run.
        debug = self.debug
        if debug:
            print(f"Wait for block {block_idx}")
        start_time = time.perf_counter()

        _, bidx_to_cuda = future.result()

        assert block_idx == bidx_to_cuda, f"Block index mismatch: {block_idx} != {bidx_to_cuda}"

        if debug:
            print(f"Waited for block {block_idx}: {time.perf_counter() - start_time:.2f}s")
164
165
166# Gradient tensors

Callers 1

offloaderFunction · 0.90

Calls

no outgoing calls

Tested by 1

offloaderFunction · 0.72