(batch: torch.Tensor)
| 15 | CALIBRATION = [torch.rand(size=INPUT_SHAPE) for _ in range(32)] |
| 16 | QS = QuantizationSettingFactory.default_setting() |
| 17 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: |
| 18 | return batch.to(DEVICE) |
| 19 | |
| 20 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) |
| 21 | model = model.to(DEVICE) |