MCPcopy
hub / github.com/geekcomputers/Python / load_model

Method load_model

ML/tests/gui_test.py:347–436  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

345 self.image_preview.setText(f'Error loading image: {e}')
346
347 def load_model(self):
348 model_path = self.model_path_input.text()
349 dataset_input = self.dataset_input.text().lower().strip()
350
351 dataset_aliases = {
352 'cifar10': 'cifar10',
353 'cifar-10': 'cifar10',
354 'cifar_10': 'cifar10',
355 'cifar100': 'cifar100',
356 'cifar-100': 'cifar100',
357 'cifar_100': 'cifar100',
358 'mnist': 'mnist',
359 'fashionmnist': 'fashion_mnist',
360 'fashion-mnist': 'fashion_mnist',
361 'fashion_mnist': 'fashion_mnist',
362 'stl10': 'stl10',
363 'stl-10': 'stl10',
364 'stl_10': 'stl10',
365 'tinyimagenet': 'tiny_imagenet',
366 'tiny-imagenet': 'tiny_imagenet',
367 'tiny_imagenet': 'tiny_imagenet',
368 'imagenet': 'imagenet',
369 'food101': 'food101',
370 'food-101': 'food101',
371 'food_101': 'food101',
372 'caltech256': 'caltech256',
373 'caltech-256': 'caltech256',
374 'caltech_256': 'caltech256',
375 'oxfordpets': 'oxford_pets',
376 'oxford-pets': 'oxford_pets',
377 'oxford_pets': 'oxford_pets',
378 }
379
380 self.dataset_name = dataset_aliases.get(dataset_input, dataset_input)
381
382 if not model_path:
383 self.model_status.setText('Please select a model file')
384 self.model_status.setStyleSheet('color: #f44336;')
385 return
386
387 if not os.path.exists(model_path):
388 self.model_status.setText('Model file not found')
389 self.model_status.setStyleSheet('color: #f44336;')
390 return
391
392 try:
393 self.model_status.setText('Loading model...')
394 self.model_status.setStyleSheet('color: #FFC107;')
395 QApplication.processEvents()
396
397 num_classes = get_num_classes(self.dataset_name)
398 self.model = ResNet18(num_classes=num_classes)
399 self.model = self.model.to(self.device)
400
401 checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
402 self.model.load_state_dict(checkpoint['model_state_dict'])
403 self.model.eval()
404

Callers 2

transcript_generator · Function · 0.45
mask_detection.py · File · 0.45

Calls 6

get_num_classes · Function · 0.90
ResNet18 · Function · 0.90
get_dataset · Function · 0.90
get_class_names · Function · 0.90
get · Method · 0.45
load · Method · 0.45

Tested by

no test coverage detected