Test the output alignment between a TensorRT engine and a PPQ graph.
(engine_file: str, graph: BaseGraph, samples: Iterable, collate_fn: Callable = None)
| 326 | |
| 327 | |
def TestAlignment(engine_file: str, graph: BaseGraph, samples: Iterable, collate_fn: Callable = None) -> dict:
    """Test alignment between a TensorRT engine and a PPQ graph.

    Every sample is executed twice: once with PPQ's ``TorchExecutor`` on ``graph``
    and once with the serialized TensorRT engine in ``engine_file``. For each name
    in ``graph.outputs`` the ``torch_snr_error`` of the PPQ result against the
    TensorRT result is averaged over all samples.

    Args:
        engine_file: Path to a serialized TensorRT engine.
        graph: PPQ graph executed with ``TorchExecutor``.
        samples: Input samples. After ``collate_fn`` (if given) each sample must be
            a ``torch.Tensor`` (only for graphs with exactly one input), a list or
            tuple of values ordered like ``graph.inputs``, or a dict keyed by input
            name.
        collate_fn: Optional callable applied to every sample before it is used.

    Returns:
        dict mapping each output name in ``graph.outputs`` to its mean SNR error.
        An empty dict is returned when ``samples`` yields nothing.

    Raises:
        TypeError: if a (collated) sample is not a tensor, list/tuple or dict.
        AssertionError: if a bare tensor is given for a graph whose input count is not 1.
        RuntimeError: if TensorRT fails to deserialize ``engine_file``.
    """
    logger = trt.Logger(trt.Logger.ERROR)

    # Normalise every sample into a {input_name: value} feed dict.
    # Materialising into a list also lets ``samples`` be a one-shot iterator.
    feed_dicts = []
    for sample in samples:
        if collate_fn is not None:
            sample = collate_fn(sample)
        if isinstance(sample, torch.Tensor):
            assert len(graph.inputs) == 1, 'Graph Needs More than 1 input tensor, however only 1 was given.'
            feed_dict = {name: sample for name in graph.inputs}
        elif isinstance(sample, (list, tuple)):
            # Positional values are matched to graph.inputs in order; zip drops any surplus.
            feed_dict = dict(zip(graph.inputs, sample))
        elif isinstance(sample, dict):
            feed_dict = sample
        else:
            raise TypeError('Given Sample is Invalid.')
        feed_dicts.append(feed_dict)

    trt_results, ppq_results = [], []
    with ENABLE_CUDA_KERNEL():
        executor = TorchExecutor(graph)
        for feed_dict in tqdm(feed_dicts, desc='PPQ Infer...'):
            ppq_results.append([value.cpu() for value in executor.forward(feed_dict)])

    # Imported for its side effect: pycuda.autoinit initialises CUDA and creates a context.
    import pycuda.autoinit  # noqa: F401
    with open(engine_file, 'rb') as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            raise RuntimeError(f'Failed to deserialize TensorRT engine from {engine_file}.')

        with engine.create_execution_context() as context:
            inputs, outputs, bindings, stream = allocate_buffers(context.engine)

            for feed_dict in tqdm(feed_dicts, desc='TensorRT Infer...'):
                # NOTE(review): assumes the engine's input buffers are ordered like
                # graph.inputs and accept the converted array as-is — confirm.
                for name, host_buffer in zip(graph.inputs, inputs):
                    host_buffer.host = convert_any_to_numpy(feed_dict[name])

                results = do_inference_v2(
                    context, bindings=bindings, inputs=inputs,
                    outputs=outputs, stream=stream)

                trt_results.append([convert_any_to_torch_tensor(value) for value in results])

    num_samples = len(feed_dicts)
    if num_samples == 0:
        return {}

    # Accumulators are initialised once so errors are summed over *all* samples.
    # NOTE(review): assumes both executors return outputs ordered like graph.outputs.
    collector = {name: 0.0 for name in graph.outputs}
    for ref, pred in zip(trt_results, ppq_results):
        for name, ref_, pred_ in zip(graph.outputs, ref, pred):
            collector[name] += torch_snr_error(pred_.reshape([1, -1]), ref_.reshape([1, -1])).item()

    return {name: total / num_samples for name, total in collector.items()}
nothing calls this directly
no test coverage detected