()
| 22 | # 和往常一样,我们要创建 calibration 数据,以及加载模型 |
| 23 | # ------------------------------------------------------------ |
| 24 | def load_calibration_dataset() -> Iterable: |
| 25 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] |
| 26 | CALIBRATION = load_calibration_dataset() |
| 27 | |
| 28 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: |