(device: str = "cpu")
| 88 | |
| 89 | |
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
    """Classify a sample image with a pretrained ResNet-50 and check the top-1 label.

    Decodes ``gallery/assets/dog2.jpg`` (relative to ``SCRIPT_DIR``). Runs it
    through ``resnet50`` initialized with ``ResNet50_Weights.DEFAULT``, using
    that weight set's own inference transforms. Prints the predicted category
    with its softmax confidence, and fails if the prediction is not
    "German shepherd".

    Args:
        device: Device string that both the decoded image and the model are
            moved to via ``.to(device)``.

    Raises:
        RuntimeError: If the top-1 predicted category is not the expected one.
    """
    import torch  # local import keeps this block self-contained

    img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights, progress=False).to(device)
    model.eval()

    # Step 2: Initialize the inference transforms bundled with these weights
    preprocess = weights.transforms(antialias=True)

    # Step 3: Apply inference preprocessing transforms, then add a leading
    # batch dimension of size 1 for the model call
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category.
    # No gradients are needed for a classification check, so run the forward
    # pass under inference_mode to skip building the autograd graph.
    with torch.inference_mode():
        prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    expected_category = "German shepherd"
    print(f"{category_name} ({device}): {100 * score:.1f}%")
    if category_name != expected_category:
        raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
| 113 | |
| 114 | |
| 115 | def main() -> None: |
no test coverage detected
searching dependent graphs…