(self)
| 359 | self.image_preview.setText(f'Error loading image: {e}') |
| 360 | |
| 361 | def load_model(self): |
| 362 | model_path = self.model_path_input.text() |
| 363 | dataset_input = self.dataset_input.text().lower().strip() |
| 364 | |
| 365 | dataset_aliases = { |
| 366 | 'cifar10': 'cifar10', 'cifar-10': 'cifar10', 'cifar_10': 'cifar10', |
| 367 | 'cifar100': 'cifar100', 'cifar-100': 'cifar100', 'cifar_100': 'cifar100', |
| 368 | 'mnist': 'mnist', |
| 369 | 'fashionmnist': 'fashion_mnist', 'fashion-mnist': 'fashion_mnist', 'fashion_mnist': 'fashion_mnist', |
| 370 | 'stl10': 'stl10', 'stl-10': 'stl10', 'stl_10': 'stl10', |
| 371 | 'tinyimagenet': 'tiny_imagenet', 'tiny-imagenet': 'tiny_imagenet', 'tiny_imagenet': 'tiny_imagenet', |
| 372 | 'imagenet': 'imagenet', |
| 373 | 'food101': 'food101', 'food-101': 'food101', 'food_101': 'food101', |
| 374 | 'caltech256': 'caltech256', 'caltech-256': 'caltech256', 'caltech_256': 'caltech256', |
| 375 | 'oxfordpets': 'oxford_pets', 'oxford-pets': 'oxford_pets', 'oxford_pets': 'oxford_pets', |
| 376 | } |
| 377 | |
| 378 | self.dataset_name = dataset_aliases.get(dataset_input, dataset_input) |
| 379 | |
| 380 | if not model_path: |
| 381 | self.model_status.setText('Please select a model file') |
| 382 | self.model_status.setStyleSheet('color: #f44336;') |
| 383 | return |
| 384 | |
| 385 | if not os.path.exists(model_path): |
| 386 | self.model_status.setText('Model file not found') |
| 387 | self.model_status.setStyleSheet('color: #f44336;') |
| 388 | return |
| 389 | |
| 390 | try: |
| 391 | self.model_status.setText('Loading model...') |
| 392 | self.model_status.setStyleSheet('color: #FFC107;') |
| 393 | QApplication.processEvents() |
| 394 | |
| 395 | num_classes = get_num_classes(self.dataset_name) |
| 396 | self.model = ResNet18(num_classes=num_classes) |
| 397 | self.model = self.model.to(self.device) |
| 398 | |
| 399 | checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) |
| 400 | self.model.load_state_dict(checkpoint['model_state_dict']) |
| 401 | self.model.eval() |
| 402 | |
| 403 | try: |
| 404 | dataset = get_dataset(self.dataset_name, train=False, download=False) |
| 405 | self.classes = getattr(dataset, 'classes', [str(i) for i in range(num_classes)]) |
| 406 | except: |
| 407 | from neuralforge.data.datasets import get_class_names |
| 408 | self.classes = get_class_names(self.dataset_name) |
| 409 | |
| 410 | self.model_status.setText(f'✓ Model loaded successfully') |
| 411 | self.model_status.setStyleSheet('color: #4CAF50;') |
| 412 | |
| 413 | self.predict_btn.setEnabled(True) |
| 414 | |
| 415 | total_params = sum(p.numel() for p in self.model.parameters()) |
| 416 | epoch = checkpoint.get('epoch', 'Unknown') |
| 417 | val_loss = checkpoint.get('best_val_loss', 'Unknown') |
| 418 |
nothing calls this directly
no test coverage detected