(self, model_path='./models/best_model.pt', dataset='cifar10', device='cuda')
| 14 | |
| 15 | class ModelTester: |
def __init__(self, model_path='./models/best_model.pt', dataset='cifar10', device='cuda'):
    """Create the model, optionally restore a checkpoint, and load the test split.

    Args:
        model_path: Path to a file readable by ``torch.load``. Two layouts are
            accepted: a dict with a ``'model_state_dict'`` entry (and optionally
            ``'epoch'``), or a bare state dict. If the path does not exist, the
            freshly created (untrained) model is kept and a warning is printed.
        dataset: Dataset name forwarded to ``get_num_classes``, ``get_dataset``
            and ``get_class_names``; it also selects ``self.image_size``.
        device: Requested torch device. A CUDA request (``'cuda'``/``'cuda:N'``)
            falls back to ``'cpu'`` when CUDA is unavailable; any other value is
            used exactly as given.
    """
    # Only a CUDA request can be unsatisfiable for lack of CUDA. Previously any
    # device (e.g. 'mps') was silently replaced by 'cpu' whenever CUDA was
    # missing, ignoring the caller's explicit choice.
    wants_cuda = str(device).startswith('cuda')
    self.device = 'cpu' if wants_cuda and not torch.cuda.is_available() else device
    self.dataset_name = dataset

    print("=" * 60)
    print(" NeuralForge - Interactive Model Testing")
    print("=" * 60)
    print(f"Device: {self.device}")

    num_classes = get_num_classes(dataset)
    self.model = self.create_model(num_classes)
    # NOTE(review): create_model is not fully visible here; confirm it moves the
    # model to self.device, since nothing in this method does.

    if os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects, so a
        # malicious checkpoint file can execute code. Only load trusted files.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            # 'epoch' is optional: a missing key should not abort loading.
            print(f"Model loaded from epoch {checkpoint.get('epoch', 'unknown')}")
        else:
            # Bare state dict, e.g. saved via torch.save(model.state_dict(), path).
            self.model.load_state_dict(checkpoint)
            print("Model loaded (checkpoint has no epoch information)")
    else:
        print(f"Warning: No model found at {model_path}, using untrained model")

    # Inference mode for layers such as dropout/batch-norm.
    self.model.eval()

    test_dataset = get_dataset(dataset, root='./data', train=False, download=True)
    # NOTE(review): assumes get_dataset returns a wrapper exposing the raw
    # dataset via `.dataset` - confirm against get_dataset.
    self.dataset = test_dataset.dataset
    self.classes = get_class_names(dataset)

    # Input resolution per known dataset; anything not listed defaults to 224.
    image_sizes = {
        'mnist': 28,
        'fashion_mnist': 28,
        'cifar10': 32,
        'cifar100': 32,
        'stl10': 96,
    }
    self.image_size = image_sizes.get(dataset, 224)

    print(f"Dataset: {dataset} ({len(self.dataset)} test samples)")
    print(f"Classes: {len(self.classes)}")
    print("=" * 60)
| 54 | |
| 55 | def create_model(self, num_classes): |
| 56 | model = ResNet18(num_classes=num_classes) |
nothing calls this directly
no test coverage detected