github.com/OpenPPL/ppq

Function dump_internal_results

ppq/api/fsys.py:197–235
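dump_internal_results(graph: BaseGraph, dataloader: Iterable, dump_dir: str, executing_device: str, sample: bool = True)

Promotes every QuantableVariable whose producing op has an ACTIVATED quantization config to a graph output, runs the graph over dataloader with a TorchExecutor on executing_device, and writes each batch's input and all graph outputs under dump_dir. When sample is true, any output with more than 10,000 elements is reduced to 4,096 values drawn with a fixed random seed.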


```python
import os
from typing import Iterable

from tqdm import tqdm

# BaseGraph, QuantableVariable, QuantizationStates, TorchExecutor,
# tensor_random_fetch, create_dir and dump_to_file come from ppq.


def dump_internal_results(
    graph: BaseGraph, dataloader: Iterable,
    dump_dir: str, executing_device: str, sample: bool = True):
    # Files written:
    #   <dump_dir>/inputs/<batch_idx>.npy                   one input batch
    #   <dump_dir>/outputs/<batch_idx>/<variable_name>.dat  one graph output
    i_dir = os.path.join(dump_dir, 'inputs')
    o_dir = os.path.join(dump_dir, 'outputs')

    create_dir(i_dir)
    create_dir(o_dir)

    # Find every quantization point and extract all intermediate results.
    for var in graph.variables.values():
        if isinstance(var, QuantableVariable):
            if (var.source_op_config is not None and
                    var.source_op_config.state == QuantizationStates.ACTIVATED):
                graph.outputs[var.name] = var  # mark it directly as a network output

    executor = TorchExecutor(graph, device=executing_device)
    for batch_idx, batch in tqdm(enumerate(dataloader),
                                 total=len(dataloader), desc='Dumping Results ...'):
        batch = batch.to(executing_device)
        outputs = executor.forward(batch)

        create_dir(os.path.join(o_dir, str(batch_idx)))
        for name, output in zip(graph.outputs, outputs):

            # Sample the tensor if it holds too many values.
            if output.numel() > 10000 and sample:
                output = tensor_random_fetch(
                    tensor=output, seed=10086,  # a fixed seed is required so results can be compared
                    num_of_fetches=4096)

            dump_to_file(
                filename=os.path.join(o_dir, str(batch_idx), name + '.dat'),
                data=output, format='.dat')

        dump_to_file(
            filename=os.path.join(i_dir, str(batch_idx) + '.npy'),
            data=batch, format='.npy')
```
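The first loop decides which tensors get dumped: a variable is promoted to a graph output only if it is a QuantableVariable and its source_op_config exists and is in the ACTIVATED state; variables with no source_op_config are skipped. A minimal, self-contained sketch of that filter follows. QuantState, OpConfig, Variable and select_dump_points are hypothetical stand-ins, not ppq's own types.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict, List, Optional


class QuantState(Enum):
    # Hypothetical stand-in for ppq's QuantizationStates (only two members modeled).
    INITIAL = auto()
    ACTIVATED = auto()


@dataclass
class OpConfig:
    # Hypothetical stand-in for the quantization config of a variable's producing op.
    state: QuantState


@dataclass
class Variable:
    name: str
    quantable: bool  # plays the role of isinstance(var, QuantableVariable)
    source_op_config: Optional[OpConfig] = None


def select_dump_points(variables: Dict[str, Variable]) -> List[str]:
    """Return names of quantable variables whose producing op config is ACTIVATED."""
    selected = []
    for var in variables.values():
        if not var.quantable:
            continue
        cfg = var.source_op_config
        if cfg is not None and cfg.state == QuantState.ACTIVATED:
            selected.append(var.name)
    return selected


graph_vars = {
    "x": Variable("x", quantable=True),  # no producing op config: skipped
    "conv_out": Variable("conv_out", True, OpConfig(QuantState.ACTIVATED)),
    "relu_out": Variable("relu_out", True, OpConfig(QuantState.INITIAL)),
    "shape_out": Variable("shape_out", False, OpConfig(QuantState.ACTIVATED)),
}
print(select_dump_points(graph_vars))  # ['conv_out']
```

In the real function the selected variables are added to graph.outputs, so the executor returns them alongside the original network outputs.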

Callers

nothing calls this directly

Calls (6)

forward (method) · 0.95
TorchExecutor (class) · 0.90
tensor_random_fetch (function) · 0.90
create_dir (function) · 0.85
dump_to_file (function) · 0.85
to (method) · 0.80
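Outputs with more than 10,000 elements are cut down to 4,096 values by tensor_random_fetch with the fixed seed 10086, so two dumps of the same tensor are sampled at the same positions and can be compared element by element. The sketch below shows one way to get that property with Python's random module on a flat list. seeded_random_fetch and maybe_sample are hypothetical names, and ppq's tensor_random_fetch may pick positions differently (for example, with or without replacement).

```python
import random
from typing import List, Sequence


def seeded_random_fetch(values: Sequence[float], seed: int = 10086,
                        num_of_fetches: int = 4096) -> List[float]:
    """Hypothetical stand-in for tensor_random_fetch on a flat list.

    The chosen positions depend only on (len(values), seed, num_of_fetches),
    so two equally sized inputs are sampled at identical positions.
    """
    rng = random.Random(seed)  # private generator; the global random state is untouched
    count = min(num_of_fetches, len(values))
    positions = sorted(rng.sample(range(len(values)), count))  # distinct indices
    return [values[i] for i in positions]


def maybe_sample(values: Sequence[float], sample: bool = True) -> List[float]:
    # Same threshold and sizes as dump_internal_results: only large outputs are sampled.
    if sample and len(values) > 10000:
        return seeded_random_fetch(values, seed=10086, num_of_fetches=4096)
    return list(values)


run_a = [float(i) for i in range(20000)]
run_b = [v + 0.5 for v in run_a]  # e.g. a second run with a constant drift
sampled_a, sampled_b = maybe_sample(run_a), maybe_sample(run_b)
print(len(sampled_a), max(b - a for a, b in zip(sampled_a, sampled_b)))  # 4096 0.5
```

Because the positions depend only on the seed and the tensor length, comparing two runs reduces to comparing the sampled lists position by position; changing the seed between runs would silently pair unrelated elements.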

Tested by

no test coverage detected
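With no test coverage, one cheap sanity check is to confirm that two dump runs (say, a reference run and a run after changing the quantization setup) produced the same set of output files before comparing their contents. The sketch assumes the layout the function writes, outputs/<batch_idx>/<variable_name>.dat; list_output_dumps and pair_dumps are hypothetical helpers, not part of ppq, and they only pair file paths without reading the .dat data.

```python
import os
from typing import Dict, List, Tuple

Key = Tuple[str, str]  # (batch_idx, variable_name)


def list_output_dumps(dump_dir: str) -> Dict[Key, str]:
    """Map (batch_idx, variable_name) to the file path under <dump_dir>/outputs."""
    found: Dict[Key, str] = {}
    o_dir = os.path.join(dump_dir, "outputs")
    if not os.path.isdir(o_dir):
        return found
    for batch_idx in sorted(os.listdir(o_dir)):
        batch_dir = os.path.join(o_dir, batch_idx)
        if not os.path.isdir(batch_dir):
            continue
        # Only one level deep: names containing path separators are not handled.
        for fname in sorted(os.listdir(batch_dir)):
            if fname.endswith(".dat"):
                found[(batch_idx, fname[:-len(".dat")])] = os.path.join(batch_dir, fname)
    return found


def pair_dumps(dir_a: str, dir_b: str) -> Tuple[List[Tuple[Key, str, str]], List[Key], List[Key]]:
    """Return (paired, only_a, only_b): matched files and keys present on one side only."""
    a, b = list_output_dumps(dir_a), list_output_dumps(dir_b)
    paired = [(k, a[k], b[k]) for k in sorted(a.keys() & b.keys())]
    only_a = sorted(a.keys() - b.keys())
    only_b = sorted(b.keys() - a.keys())
    return paired, only_a, only_b
```

A non-empty only_a or only_b usually means the two runs marked different variables as ACTIVATED quantization points or consumed a different number of batches.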