(
graph: BaseGraph, dataloader: Iterable,
dump_dir: str, executing_device: str, sample: bool = True)
| 195 | |
| 196 | |
def dump_internal_results(
    graph: BaseGraph, dataloader: Iterable,
    dump_dir: str, executing_device: str, sample: bool = True):
    """Run ``graph`` over ``dataloader`` and dump every activated quantization point.

    Every ``QuantableVariable`` whose ``source_op_config`` is set and whose state
    is ``QuantizationStates.ACTIVATED`` is registered in ``graph.outputs``. A
    single forward pass then returns all of these intermediate results. For each
    batch the following files are written (via ``dump_to_file``)::

        <dump_dir>/outputs/<batch_idx>/<variable name>.dat
        <dump_dir>/inputs/<batch_idx>.npy

    Args:
        graph: Graph to execute. It is modified in place: the activated
            variables remain registered in ``graph.outputs`` after this returns.
        dataloader: Iterable of input batches. Each batch must support
            ``.to(device)`` (presumably a ``torch.Tensor`` -- confirm with
            callers). Iterables without ``__len__`` (e.g. generators) are
            accepted. In that case the progress bar shows a running count
            without a total.
        dump_dir: Root output directory. The ``inputs`` and ``outputs``
            subdirectories are created inside it.
        executing_device: Device string passed to ``TorchExecutor`` and to
            ``batch.to``.
        sample: When True, any output with more than 10000 elements is replaced
            by ``tensor_random_fetch(..., num_of_fetches=4096)``. The random seed
            is fixed so that dumps stay comparable.
    """
    i_dir = os.path.join(dump_dir, 'inputs')
    o_dir = os.path.join(dump_dir, 'outputs')
    create_dir(i_dir)
    create_dir(o_dir)

    # Find every activated quantization point and mark it as a network output,
    # so that the executor returns all of these intermediate results.
    for var in graph.variables.values():
        if not isinstance(var, QuantableVariable):
            continue
        config = var.source_op_config
        if config is not None and config.state == QuantizationStates.ACTIVATED:
            graph.outputs[var.name] = var

    executor = TorchExecutor(graph, device=executing_device)

    # ``Iterable`` only promises ``__iter__``, and len() raises TypeError on
    # e.g. generators. tqdm accepts total=None and then shows no total.
    try:
        total = len(dataloader)
    except TypeError:
        total = None

    for batch_idx, batch in tqdm(enumerate(dataloader),
                                 total=total, desc='Dumping Results ...'):
        batch = batch.to(executing_device)
        outputs = executor.forward(batch)

        batch_out_dir = os.path.join(o_dir, str(batch_idx))
        create_dir(batch_out_dir)

        # NOTE(review): the pairing below assumes executor.forward returns its
        # results in graph.outputs iteration order. zip() silently truncates if
        # the two lengths differ.
        for name, output in zip(graph.outputs, outputs):
            # Sample when the tensor has too many elements. The seed is fixed
            # so that results from different runs can be compared.
            if sample and output.numel() > 10000:
                output = tensor_random_fetch(
                    tensor=output, seed=10086, num_of_fetches=4096)

            # NOTE(review): ``name`` is used verbatim as a file name. If graph
            # variable names can contain path separators, this path breaks.
            # Confirm.
            dump_to_file(
                filename=os.path.join(batch_out_dir, name + '.dat'),
                data=output, format='.dat')

        dump_to_file(
            filename=os.path.join(i_dir, str(batch_idx) + '.npy'),
            data=batch, format='.npy')
| 236 | |
| 237 | |
| 238 | def split_result_to_directory(raw_dir: str, to_dir: str): |
nothing calls this directly
no test coverage detected