MCPcopy
hub / github.com/geekcomputers/Python / load_model

Method load_model

ML/src/python/neuralforge/cli/gui.py:361–434  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

359 self.image_preview.setText(f'Error loading image: {e}')
360
361 def load_model(self):
362 model_path = self.model_path_input.text()
363 dataset_input = self.dataset_input.text().lower().strip()
364
365 dataset_aliases = {
366 'cifar10': 'cifar10', 'cifar-10': 'cifar10', 'cifar_10': 'cifar10',
367 'cifar100': 'cifar100', 'cifar-100': 'cifar100', 'cifar_100': 'cifar100',
368 'mnist': 'mnist',
369 'fashionmnist': 'fashion_mnist', 'fashion-mnist': 'fashion_mnist', 'fashion_mnist': 'fashion_mnist',
370 'stl10': 'stl10', 'stl-10': 'stl10', 'stl_10': 'stl10',
371 'tinyimagenet': 'tiny_imagenet', 'tiny-imagenet': 'tiny_imagenet', 'tiny_imagenet': 'tiny_imagenet',
372 'imagenet': 'imagenet',
373 'food101': 'food101', 'food-101': 'food101', 'food_101': 'food101',
374 'caltech256': 'caltech256', 'caltech-256': 'caltech256', 'caltech_256': 'caltech256',
375 'oxfordpets': 'oxford_pets', 'oxford-pets': 'oxford_pets', 'oxford_pets': 'oxford_pets',
376 }
377
378 self.dataset_name = dataset_aliases.get(dataset_input, dataset_input)
379
380 if not model_path:
381 self.model_status.setText('Please select a model file')
382 self.model_status.setStyleSheet('color: #f44336;')
383 return
384
385 if not os.path.exists(model_path):
386 self.model_status.setText('Model file not found')
387 self.model_status.setStyleSheet('color: #f44336;')
388 return
389
390 try:
391 self.model_status.setText('Loading model...')
392 self.model_status.setStyleSheet('color: #FFC107;')
393 QApplication.processEvents()
394
395 num_classes = get_num_classes(self.dataset_name)
396 self.model = ResNet18(num_classes=num_classes)
397 self.model = self.model.to(self.device)
398
399 checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
400 self.model.load_state_dict(checkpoint['model_state_dict'])
401 self.model.eval()
402
403 try:
404 dataset = get_dataset(self.dataset_name, train=False, download=False)
405 self.classes = getattr(dataset, 'classes', [str(i) for i in range(num_classes)])
406 except:
407 from neuralforge.data.datasets import get_class_names
408 self.classes = get_class_names(self.dataset_name)
409
410 self.model_status.setText(f'✓ Model loaded successfully')
411 self.model_status.setStyleSheet('color: #4CAF50;')
412
413 self.predict_btn.setEnabled(True)
414
415 total_params = sum(p.numel() for p in self.model.parameters())
416 epoch = checkpoint.get('epoch', 'Unknown')
417 val_loss = checkpoint.get('best_val_loss', 'Unknown')
418

Callers

nothing calls this directly

Calls 6

get_num_classes — Function · 0.90
ResNet18 — Function · 0.90
get_dataset — Function · 0.90
get_class_names — Function · 0.90
get — Method · 0.45
load — Method · 0.45

Tested by

no test coverage detected