hub / github.com/pytorch/vision / smoke_test_torchvision_resnet50_classify

Function smoke_test_torchvision_resnet50_classify

test/smoke_test.py:90–112
(device: str = "cpu")

Source from the content-addressed store, hash-verified

def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
    img = decode_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)

    # Step 1: Initialize model with the best available weights
    weights = ResNet50_Weights.DEFAULT
    model = resnet50(weights=weights, progress=False).to(device)
    model.eval()

    # Step 2: Initialize the inference transforms
    preprocess = weights.transforms(antialias=True)

    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)

    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]
    expected_category = "German shepherd"
    print(f"{category_name} ({device}): {100 * score:.1f}%")
    if category_name != expected_category:
        raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
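
Step 4 above turns raw model outputs into a category name and a score, then fails loudly on a mismatch. The following is a minimal pure-Python sketch of that post-processing logic only, not the torchvision implementation. The helper names (`softmax`, `classify`, `check_prediction`), the three made-up logits, and the three-entry category list standing in for `weights.meta["categories"]` are all assumptions for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits, categories):
    """Return (category_name, score) for the highest-probability class."""
    probs = softmax(logits)
    # Equivalent of prediction.argmax().item() on a 1-D tensor.
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return categories[class_id], probs[class_id]


def check_prediction(category_name, expected_category):
    """Mirror the smoke test's failure mode: raise on an unexpected label."""
    if category_name != expected_category:
        raise RuntimeError(
            f"Failed classify {category_name} Expected: {expected_category}"
        )


# Hypothetical stand-ins for the model output and weights.meta["categories"].
categories = ["tabby", "German shepherd", "golden retriever"]
logits = [0.5, 3.0, 1.0]

category_name, score = classify(logits, categories)
print(f"{category_name}: {100 * score:.1f}%")
check_prediction(category_name, "German shepherd")
```

In the real function the same three operations run on a tensor of 1000 ImageNet class scores; the sketch only shows why the check compares the category name rather than the raw score.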

Callers 1

main · Function · 0.85

Calls 6

decode_image · Function · 0.90
resnet50 · Function · 0.90
preprocess · Function · 0.85
to · Method · 0.80
transforms · Method · 0.80
print · Function · 0.50

Tested by

no test coverage detected
