# At this point your calibration dataset should be a list of dictionaries,
# each mapping a graph input name to its tensor.
# ------------------------------------------------------------
def generate_calibration_dataset(graph: BaseGraph, num_of_batches: int = 32) -> Tuple[Iterable[dict], dict]:
    """Build a random calibration dataset for ``graph``.

    Each batch is a dict with one entry per name in ``graph.inputs``. The value
    is ``torch.rand(INPUT_SHAPES[name])``, i.e. uniform random values with the
    size recorded for that input in the module-level ``INPUT_SHAPES`` mapping.

    Args:
        graph: Graph whose ``inputs`` names become the keys of every batch.
        num_of_batches: Number of batches to generate; must be at least 1.

    Returns:
        A tuple ``(dataset, sample)`` where ``dataset`` is the list of batch
        dicts and ``sample`` is the last batch in that list.

    Raises:
        ValueError: If ``num_of_batches`` is less than 1. Without this check
            there would be no "last sample" to return, and the original code
            failed with an ``UnboundLocalError``.
    """
    if num_of_batches < 1:
        raise ValueError(f"num_of_batches must be >= 1, got {num_of_batches}")
    # Batches are drawn in the same order as the original loop (batch by batch,
    # input by input), so results under a fixed torch seed are unchanged.
    dataset = [
        {name: torch.rand(INPUT_SHAPES[name]) for name in graph.inputs}
        for _ in range(num_of_batches)
    ]
    return dataset, dataset[-1]
| 30 | |
def collate_fn(batch: dict, *, device=None) -> dict:
    """Move every value of a calibration batch onto the target device.

    Args:
        batch: Mapping whose values support ``.to(device)`` (tensors).
        device: Optional destination device. When omitted (``None``), the
            module-level ``DEVICE`` is used, which matches the original
            behaviour. This argument is keyword-only, so existing
            single-argument calls are unaffected.

    Returns:
        A new dict with the same keys, where each value is ``value.to(device)``.
        The input dict is not modified.
    """
    # Resolve DEVICE lazily so an explicit override never touches the global.
    target = DEVICE if device is None else device
    return {name: tensor.to(target) for name, tensor in batch.items()}