Collect training data for the given block. This function will collect fp32 output and quantized input data by executing your graph twice. For collecting fp32 output, all related operations will be dequantized. For collecting quantized input, all related operations' quantization state will be restored.
(
self, graph: BaseGraph, block: TrainableBlock, executor: TorchExecutor,
dataloader: Iterable, collate_fn: Callable, collecting_device: str, steps: int = None,
expire_device: str = 'cpu')
| 222 | return ret |
| 223 | |
| 224 | def collect( |
| 225 | self, graph: BaseGraph, block: TrainableBlock, executor: TorchExecutor, |
| 226 | dataloader: Iterable, collate_fn: Callable, collecting_device: str, steps: int = None, |
| 227 | expire_device: str = 'cpu') -> Tuple[List[Dict[str, torch.Tensor]], List[Dict[str, torch.Tensor]]]: |
| 228 | """ |
| 229 | Collect training data for given block. |
| 230 | This function will collect fp32 output and quantized input data by |
| 231 | executing your graph twice. |
| 232 | For collecting fp32 output, all related operations will be dequantized. |
| 233 | For collecting quantized input, all related operations' quantization state will be restored. |
| 234 | |
| 235 | collecting device declares where cache to be stored: |
| 236 | executor - store cache to executor device.(default) |
| 237 | cpu - store cache to system memory. |
| 238 | cuda - store cache to gpu memory.(2x speed up) |
| 239 | disk - not implemented. |
| 240 | |
| 241 | Args: |
| 242 | block (TrainableBlock): _description_ |
| 243 | executor (TorchExecutor): _description_ |
| 244 | dataloader (Iterable): _description_ |
| 245 | collate_fn (Callable): _description_ |
| 246 | collecting_device (str): _description_ |
| 247 | |
| 248 | Returns: |
| 249 | _type_: _description_ |
| 250 | """ |
| 251 | def cache_fn(data: torch.Tensor): |
| 252 | # TODO move this function to ppq.core.IO |
| 253 | if not isinstance(data, torch.Tensor): |
| 254 | raise TypeError('Unexpected Type of value, Except network output to be torch.Tensor, ' |
| 255 | f'however {type(data)} was given.') |
| 256 | if collecting_device == 'cpu': data = data.cpu() |
| 257 | if collecting_device == 'cuda': data = data.cuda() |
| 258 | # TODO restrict collecting device. |
| 259 | return data |
| 260 | |
| 261 | with torch.no_grad(): |
| 262 | try: |
| 263 | if len(dataloader) > 1024: |
| 264 | ppq_warning('Large finetuning dataset detected(>1024). ' |
| 265 | 'You are suppose to prepare a smaller dataset for finetuning. ' |
| 266 | 'Large dataset might cause system out of memory, ' |
| 267 | 'cause all data are cache in memory.') |
| 268 | except Exception as e: |
| 269 | pass # dataloader has no __len__ |
| 270 | |
| 271 | quant_graph = QuantableGraph(graph) # helper class |
| 272 | fp_outputs, qt_inputs = [], [] |
| 273 | |
| 274 | cur_iter = 0 |
| 275 | # dequantize graph, collect fp32 outputs |
| 276 | quant_graph.dequantize_graph(expire_device=expire_device) |
| 277 | for data in dataloader: |
| 278 | if collate_fn is not None: data = collate_fn(data) |
| 279 | fp_output = executor.forward(data, [var.name for var in block.ep.outputs]) |
| 280 | fp_output = {var.name: cache_fn(data) for data, var in zip(fp_output, block.ep.outputs)} |
| 281 | fp_outputs.append(fp_output) |
no test coverage detected