MCP · copy · Index your code
hub / github.com/geekcomputers/Python / ModelTester

Class ModelTester

ML/tests/test_model.py:15–235  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

13from src.python.neuralforge.models.resnet import ResNet18
14
15class ModelTester:
16 def __init__(self, model_path='./models/best_model.pt', dataset='cifar10', device='cuda'):
17 self.device = device if torch.cuda.is_available() else 'cpu'
18 self.dataset_name = dataset
19
20 print("=" * 60)
21 print(" NeuralForge - Interactive Model Testing")
22 print("=" * 60)
23 print(f"Device: {self.device}")
24
25 num_classes = get_num_classes(dataset)
26 self.model = self.create_model(num_classes)
27
28 if os.path.exists(model_path):
29 print(f"Loading model from: {model_path}")
30 checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
31 self.model.load_state_dict(checkpoint['model_state_dict'])
32 print(f"Model loaded from epoch {checkpoint['epoch']}")
33 else:
34 print(f"Warning: No model found at {model_path}, using untrained model")
35
36 self.model.eval()
37
38 test_dataset = get_dataset(dataset, root='./data', train=False, download=True)
39 self.dataset = test_dataset.dataset
40 self.classes = get_class_names(dataset)
41
42 if dataset in ['mnist', 'fashion_mnist']:
43 self.image_size = 28
44 elif dataset in ['cifar10', 'cifar100']:
45 self.image_size = 32
46 elif dataset == 'stl10':
47 self.image_size = 96
48 else:
49 self.image_size = 224
50
51 print(f"Dataset: {dataset} ({len(self.dataset)} test samples)")
52 print(f"Classes: {len(self.classes)}")
53 print("=" * 60)
54
55 def create_model(self, num_classes):
56 model = ResNet18(num_classes=num_classes)
57 return model.to(self.device)
58
59 def predict_image(self, image_tensor):
60 with torch.no_grad():
61 image_tensor = image_tensor.unsqueeze(0).to(self.device)
62 outputs = self.model(image_tensor)
63 probabilities = F.softmax(outputs, dim=1)
64 confidence, predicted = torch.max(probabilities, 1)
65
66 top5_prob, top5_idx = torch.topk(probabilities, min(5, len(self.classes)), dim=1)
67
68 return predicted.item(), confidence.item(), top5_idx[0].cpu().numpy(), top5_prob[0].cpu().numpy()
69
70 def test_random_samples(self, num_samples=10):
71 print(f"\nTesting {num_samples} random samples...")
72 print("-" * 60)

Callers 1

mainFunction · 0.85

Calls

no outgoing calls

Tested by 1

mainFunction · 0.68