()
| 2 | import os |
| 3 | |
| 4 | def main(): |
| 5 | try: |
| 6 | from PyQt6.QtWidgets import QApplication |
| 7 | except ImportError: |
| 8 | print("Error: PyQt6 not installed") |
| 9 | print("Install with: pip install neuralforge[gui]") |
| 10 | print("Or: pip install PyQt6") |
| 11 | sys.exit(1) |
| 12 | |
| 13 | current_dir = os.path.dirname(os.path.abspath(__file__)) |
| 14 | root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_dir)))) |
| 15 | |
| 16 | sys.path.insert(0, root_dir) |
| 17 | |
| 18 | from PyQt6.QtWidgets import (QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, |
| 19 | QPushButton, QLabel, QLineEdit, QFileDialog, |
| 20 | QProgressBar, QTextEdit, QGroupBox) |
| 21 | from PyQt6.QtCore import Qt, QThread, pyqtSignal |
| 22 | from PyQt6.QtGui import QPixmap, QFont |
| 23 | |
| 24 | import torch |
| 25 | import torch.nn.functional as F |
| 26 | from torchvision import transforms |
| 27 | from PIL import Image |
| 28 | |
| 29 | from neuralforge.data.datasets import get_dataset, get_num_classes |
| 30 | from neuralforge.models.resnet import ResNet18 |
| 31 | |
| 32 | class PredictionThread(QThread): |
| 33 | finished = pyqtSignal(list, list, str) |
| 34 | error = pyqtSignal(str) |
| 35 | |
| 36 | def __init__(self, model, image_path, classes, device): |
| 37 | super().__init__() |
| 38 | self.model = model |
| 39 | self.image_path = image_path |
| 40 | self.classes = classes |
| 41 | self.device = device |
| 42 | |
| 43 | def run(self): |
| 44 | try: |
| 45 | image = Image.open(self.image_path).convert('RGB') |
| 46 | |
| 47 | transform = transforms.Compose([ |
| 48 | transforms.Resize(256), |
| 49 | transforms.CenterCrop(224), |
| 50 | transforms.ToTensor(), |
| 51 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| 52 | ]) |
| 53 | |
| 54 | image_tensor = transform(image).unsqueeze(0).to(self.device) |
| 55 | |
| 56 | with torch.no_grad(): |
| 57 | outputs = self.model(image_tensor) |
| 58 | probabilities = F.softmax(outputs, dim=1) |
| 59 | |
| 60 | top5_prob, top5_idx = torch.topk(probabilities, min(5, len(self.classes)), dim=1) |
| 61 |
no test coverage detected