(self)
| 345 | self.image_preview.setText(f'Error loading image: {e}') |
| 346 | |
| 347 | def load_model(self): |
| 348 | model_path = self.model_path_input.text() |
| 349 | dataset_input = self.dataset_input.text().lower().strip() |
| 350 | |
| 351 | dataset_aliases = { |
| 352 | 'cifar10': 'cifar10', |
| 353 | 'cifar-10': 'cifar10', |
| 354 | 'cifar_10': 'cifar10', |
| 355 | 'cifar100': 'cifar100', |
| 356 | 'cifar-100': 'cifar100', |
| 357 | 'cifar_100': 'cifar100', |
| 358 | 'mnist': 'mnist', |
| 359 | 'fashionmnist': 'fashion_mnist', |
| 360 | 'fashion-mnist': 'fashion_mnist', |
| 361 | 'fashion_mnist': 'fashion_mnist', |
| 362 | 'stl10': 'stl10', |
| 363 | 'stl-10': 'stl10', |
| 364 | 'stl_10': 'stl10', |
| 365 | 'tinyimagenet': 'tiny_imagenet', |
| 366 | 'tiny-imagenet': 'tiny_imagenet', |
| 367 | 'tiny_imagenet': 'tiny_imagenet', |
| 368 | 'imagenet': 'imagenet', |
| 369 | 'food101': 'food101', |
| 370 | 'food-101': 'food101', |
| 371 | 'food_101': 'food101', |
| 372 | 'caltech256': 'caltech256', |
| 373 | 'caltech-256': 'caltech256', |
| 374 | 'caltech_256': 'caltech256', |
| 375 | 'oxfordpets': 'oxford_pets', |
| 376 | 'oxford-pets': 'oxford_pets', |
| 377 | 'oxford_pets': 'oxford_pets', |
| 378 | } |
| 379 | |
| 380 | self.dataset_name = dataset_aliases.get(dataset_input, dataset_input) |
| 381 | |
| 382 | if not model_path: |
| 383 | self.model_status.setText('Please select a model file') |
| 384 | self.model_status.setStyleSheet('color: #f44336;') |
| 385 | return |
| 386 | |
| 387 | if not os.path.exists(model_path): |
| 388 | self.model_status.setText('Model file not found') |
| 389 | self.model_status.setStyleSheet('color: #f44336;') |
| 390 | return |
| 391 | |
| 392 | try: |
| 393 | self.model_status.setText('Loading model...') |
| 394 | self.model_status.setStyleSheet('color: #FFC107;') |
| 395 | QApplication.processEvents() |
| 396 | |
| 397 | num_classes = get_num_classes(self.dataset_name) |
| 398 | self.model = ResNet18(num_classes=num_classes) |
| 399 | self.model = self.model.to(self.device) |
| 400 | |
| 401 | checkpoint = torch.load(model_path, map_location=self.device, weights_only=False) |
| 402 | self.model.load_state_dict(checkpoint['model_state_dict']) |
| 403 | self.model.eval() |
| 404 |
no test coverage detected