MCPcopy
hub / github.com/geekcomputers/Python / __init__

Method __init__

ML/tests/test_model.py:16–53  ·  view source on GitHub ↗
(self, model_path='./models/best_model.pt', dataset='cifar10', device='cuda')

Source from the content-addressed store, hash-verified

14
15class ModelTester:
def __init__(self, model_path='./models/best_model.pt', dataset='cifar10', device='cuda'):
    """Create the model, optionally restore a checkpoint, and load the test data.

    Args:
        model_path: Path to a checkpoint file. If the path does not exist,
            a warning is printed and the freshly created model is used.
        dataset: Dataset name. Passed to ``get_num_classes``,
            ``get_dataset`` and ``get_class_names``, and used to pick
            ``self.image_size``.
        device: Requested device. Used only when
            ``torch.cuda.is_available()`` is true; otherwise ``'cpu'``.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model_state_dict'``
            entry.
    """
    self.device = device if torch.cuda.is_available() else 'cpu'
    self.dataset_name = dataset

    banner = "=" * 60
    print(banner)
    print(" NeuralForge - Interactive Model Testing")
    print(banner)
    print(f"Device: {self.device}")

    num_classes = get_num_classes(dataset)
    # NOTE(review): the model is not moved to self.device in this method;
    # presumably create_model handles placement -- confirm.
    self.model = self.create_model(num_classes)

    if os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary objects. Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        # The epoch is informational only; a checkpoint saved without it
        # must not abort construction after the weights loaded fine.
        epoch = checkpoint.get('epoch', 'unknown')
        print(f"Model loaded from epoch {epoch}")
    else:
        print(f"Warning: No model found at {model_path}, using untrained model")

    self.model.eval()

    test_dataset = get_dataset(dataset, root='./data', train=False, download=True)
    self.dataset = test_dataset.dataset
    self.classes = get_class_names(dataset)

    # Input edge length per known dataset; anything else falls back to 224.
    image_sizes = {
        'mnist': 28,
        'fashion_mnist': 28,
        'cifar10': 32,
        'cifar100': 32,
        'stl10': 96,
    }
    self.image_size = image_sizes.get(dataset, 224)

    print(f"Dataset: {dataset} ({len(self.dataset)} test samples)")
    print(f"Classes: {len(self.classes)}")
    print(banner)
54
55 def create_model(self, num_classes):
56 model = ResNet18(num_classes=num_classes)

Callers

nothing calls this directly

Calls 5

create_modelMethod · 0.95
get_num_classesFunction · 0.90
get_datasetFunction · 0.90
get_class_namesFunction · 0.90
loadMethod · 0.45

Tested by

no test coverage detected