(self, graph: BaseGraph,
dataloader: Iterable, executor: TorchExecutor,
collate_fn: Callable, **kwargs)
| 110 | return output_collector |
| 111 | |
| 112 | def optimize(self, graph: BaseGraph, |
| 113 | dataloader: Iterable, executor: TorchExecutor, |
| 114 | collate_fn: Callable, **kwargs) -> None: |
| 115 | |
| 116 | block_builder = BlockBuilder(graph=graph, topo_order=executor._executing_order) |
| 117 | |
| 118 | # check if there is any baked value inside your graph |
| 119 | for operation in graph.operations.values(): |
| 120 | if isinstance(operation, QuantableOperation): |
| 121 | for cfg, var in operation.config_with_variable: |
| 122 | if cfg.state in {QuantizationStates.BAKED, QuantizationStates.PASSIVE_BAKED}: |
| 123 | raise PermissionError('Can not apply advanced optimization pass when weight value is baked. ' |
| 124 | f'Variable {var.name} has a baked value.') |
| 125 | |
| 126 | # build all blocks, drop overlapped layers. |
| 127 | blocks, visited = [], set() |
| 128 | for op in graph.operations.values(): |
| 129 | if op in visited: continue |
| 130 | block = block_builder.build(op, limit=OPTIM_ADVOPT_GRAPH_MAXDEPTH) |
| 131 | |
| 132 | # PATCH 20220317 drop block that has no computing op. |
| 133 | if all([rp.is_computing_op == False for rp in block.rps]): continue |
| 134 | if block.sp.is_computing_op == False: continue |
| 135 | |
| 136 | for rp in block.rps: visited.add(rp) |
| 137 | blocks.append(block) |
| 138 | |
| 139 | self.initialize_checkpoints( |
| 140 | graph=graph, executor=executor, |
| 141 | dataloader=dataloader, collate_fn=collate_fn) |
| 142 | |
| 143 | block_builder = BlockBuilder(graph=graph, topo_order=executor._executing_order) |
| 144 | for bidx, block in enumerate(blocks): |
| 145 | self._bidx, self._num_of_blocks = bidx, len(blocks) |
| 146 | assert isinstance(block, TrainableBlock) |
| 147 | |
| 148 | end_op = block.ep |
| 149 | block_input = block.sp.inputs[0] |
| 150 | block_output = end_op.outputs[0] |
| 151 | |
| 152 | # dequantize prefix operations and block operations |
| 153 | for op in graph.operations.values(): |
| 154 | if isinstance(op, QuantableOperation): |
| 155 | op.dequantize() |
| 156 | # can not use dequantize_immediately cause weight has been changed. |
| 157 | # self.dequantize_immediately(op) |
| 158 | |
| 159 | fp32_outputs = self.collect_training_data( |
| 160 | output_name=block_output.name, dataloader=dataloader, |
| 161 | executor=executor, collate_fn=collate_fn) |
| 162 | |
| 163 | # quantize prefix operations and block operations |
| 164 | for op in graph.operations.values(): |
| 165 | if isinstance(op, QuantableOperation): |
| 166 | op.restore_quantize_state() |
| 167 | |
| 168 | quant_inputs = self.collect_training_data( |
| 169 | output_name= block_input.name, dataloader=dataloader, |
nothing calls this directly
no test coverage detected