hub / github.com/OpenPPL/ppq / collect

Method collect

ppq/quantization/optim/training.py:224–298

Collects training data for a given block. The function gathers fp32 outputs and quantized inputs by executing the graph twice. To collect fp32 outputs, all related operations are dequantized. To collect quantized inputs, the quantization state of all related operations is restored.
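The two-pass idea described above can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyBlock`, `fake_quant`, `block_input` and `collect_sketch` are hypothetical names invented here, not part of PPQ, and a single boolean flag stands in for dequantizing the graph and restoring its quantization state.

```python
from typing import Callable, Dict, Iterable, List, Optional, Tuple


class ToyBlock:
    """Hypothetical stand-in for a quantizable block computing y = 0.37 * x."""

    def __init__(self) -> None:
        self.quantized = True  # quantization is enabled by default

    @staticmethod
    def fake_quant(value: float, scale: float = 0.5) -> float:
        # Round to the nearest multiple of `scale` to mimic quantization error.
        return round(value / scale) * scale

    def block_input(self, x: float) -> float:
        # The value entering the block; quantized only while the flag is on.
        return self.fake_quant(x) if self.quantized else x

    def forward(self, x: float) -> float:
        return 0.37 * self.block_input(x)


def collect_sketch(
    block: ToyBlock,
    dataloader: Iterable,
    collate_fn: Optional[Callable] = None,
) -> Tuple[List[Dict[str, float]], List[Dict[str, float]]]:
    fp_outputs: List[Dict[str, float]] = []
    qt_inputs: List[Dict[str, float]] = []

    # Pass 1: quantization off, record full-precision block outputs.
    block.quantized = False
    for data in dataloader:
        if collate_fn is not None:
            data = collate_fn(data)
        fp_outputs.append({"out": block.forward(data)})

    # Pass 2: quantization restored, record quantized block inputs.
    block.quantized = True
    for data in dataloader:
        if collate_fn is not None:
            data = collate_fn(data)
        qt_inputs.append({"in": block.block_input(data)})

    return fp_outputs, qt_inputs


fp_outputs, qt_inputs = collect_sketch(ToyBlock(), ["1.2", "2.6"], collate_fn=float)
```

Because the sketch walks the data twice, it assumes a re-iterable container such as a list; a one-shot generator would be exhausted after the first pass.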

(
        self, graph: BaseGraph, block: TrainableBlock, executor: TorchExecutor, 
        dataloader: Iterable, collate_fn: Callable, collecting_device: str, steps: int = None,
        expire_device: str = 'cpu')


        return ret

    def collect(
        self, graph: BaseGraph, block: TrainableBlock, executor: TorchExecutor,
        dataloader: Iterable, collate_fn: Callable, collecting_device: str, steps: int = None,
        expire_device: str = 'cpu') -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]:
        """
        Collect training data for given block.
        This function will collect fp32 output and quantized input data by
        executing your graph twice.
        For collecting fp32 output, all related operations will be dequantized.
        For collecting quantized input, all related operations' quantization state will be restored.

        collecting device declares where cache to be stored:
            executor - store cache to executor device.(default)
            cpu - store cache to system memory.
            cuda - store cache to gpu memory.(2x speed up)
            disk - not implemented.

        Args:
            block (TrainableBlock): _description_
            executor (TorchExecutor): _description_
            dataloader (Iterable): _description_
            collate_fn (Callable): _description_
            collecting_device (str): _description_

        Returns:
            _type_: _description_
        """
        def cache_fn(data: torch.Tensor):
            # TODO move this function to ppq.core.IO
            if not isinstance(data, torch.Tensor):
                raise TypeError('Unexpected Type of value, Except network output to be torch.Tensor, '
                                f'however {type(data)} was given.')
            if collecting_device == 'cpu': data = data.cpu()
            if collecting_device == 'cuda': data = data.cuda()
            # TODO restrict collecting device.
            return data

        with torch.no_grad():
            try:
                if len(dataloader) > 1024:
                    ppq_warning('Large finetuning dataset detected(>1024). '
                                'You are suppose to prepare a smaller dataset for finetuning. '
                                'Large dataset might cause system out of memory, '
                                'cause all data are cache in memory.')
            except Exception as e:
                pass # dataloader has no __len__

            quant_graph = QuantableGraph(graph) # helper class
            fp_outputs, qt_inputs = [], []

            cur_iter = 0
            # dequantize graph, collect fp32 outputs
            quant_graph.dequantize_graph(expire_device=expire_device)
            for data in dataloader:
                if collate_fn is not None: data = collate_fn(data)
                fp_output = executor.forward(data, [var.name for var in block.ep.outputs])
                fp_output = {var.name: cache_fn(data) for data, var in zip(fp_output, block.ep.outputs)}
                fp_outputs.append(fp_output)
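The dataset-size guard in the listing has to tolerate dataloaders that are plain iterables without `__len__`. A minimal stand-alone sketch of that pattern follows, assuming the standard `warnings` module as a stand-in for `ppq_warning`. The helper names `dataset_size` and `warn_if_large` are hypothetical. Unlike the listing, which catches `Exception` broadly, the sketch catches only `TypeError`, which is what `len()` raises for objects that do not define `__len__`.

```python
import warnings
from typing import Iterable, Optional


def dataset_size(dataloader: Iterable) -> Optional[int]:
    """Return len(dataloader), or None when the iterable has no __len__."""
    try:
        return len(dataloader)  # type: ignore[arg-type]
    except TypeError:
        return None


def warn_if_large(dataloader: Iterable, limit: int = 1024) -> bool:
    """Warn when a sized dataloader exceeds `limit`; report whether it did."""
    size = dataset_size(dataloader)
    if size is not None and size > limit:
        warnings.warn(
            f"Large finetuning dataset detected ({size} > {limit}); "
            "every batch is cached in memory."
        )
        return True
    return False
```

A generator passes through silently because its size is unknown, so the guard is only advisory and does not by itself bound memory use.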

Callers 4

optimize (Method) · 0.80
optimize (Method) · 0.80
optimize (Method) · 0.80
_wrapper (Function) · 0.80

Calls 7

dequantize_graph (Method) · 0.95
QuantableGraph (Class) · 0.90
ppq_warning (Function) · 0.85
collate_fn (Function) · 0.50
forward (Method) · 0.45
append (Method) · 0.45

Tested by

no test coverage detected