| 62 | self.error.emit(str(e)) |
| 63 | |
| 64 | class NeuralForgeGUI(QMainWindow): |
def __init__(self):
    """Initialise the main window, set default state and build the UI.

    State set here before any UI is built:
      - ``model``: no model loaded yet (``None``).
      - ``device``: ``'cuda'`` when PyTorch reports a CUDA device, otherwise ``'cpu'``.
      - ``classes``: empty list of class labels.
      - ``dataset_name``: ``'cifar10'``.

    Finally calls ``init_ui()`` and then ``apply_stylesheet()``.
    """
    super().__init__()

    # Model-related state starts out empty.
    self.model = None
    self.classes = []
    self.dataset_name = 'cifar10'

    # Pick the compute device once, preferring CUDA when it is available.
    if torch.cuda.is_available():
        self.device = 'cuda'
    else:
        self.device = 'cpu'

    # The attributes above must exist before the widgets are built.
    self.init_ui()
    self.apply_stylesheet()
| 74 | |
def init_ui(self):
    """Build the window frame and its two-column central layout.

    Sets the window title and geometry, installs a central widget with a
    horizontal layout, and places the left and right panels side by side.
    """
    self.setWindowTitle('NeuralForge - Model Tester')
    self.setGeometry(100, 100, 1200, 800)

    container = QWidget()
    self.setCentralWidget(container)

    columns = QHBoxLayout()
    container.setLayout(columns)

    # The tuple builds both panels (left first, then right) before either
    # is added. Each gets the same stretch factor of 1.
    panels = (self.create_left_panel(), self.create_right_panel())
    for panel in panels:
        columns.addWidget(panel, 1)
| 90 | |
| 91 | def create_left_panel(self): |
| 92 | panel = QWidget() |
| 93 | layout = QVBoxLayout() |
| 94 | panel.setLayout(layout) |
| 95 | |
| 96 | title = QLabel('🚀 NeuralForge Model Tester') |
| 97 | title.setFont(QFont('Arial', 20, QFont.Weight.Bold)) |
| 98 | title.setAlignment(Qt.AlignmentFlag.AlignCenter) |
| 99 | layout.addWidget(title) |
| 100 | |
| 101 | model_group = QGroupBox('Model Selection') |
| 102 | model_layout = QVBoxLayout() |
| 103 | |
| 104 | model_path_layout = QHBoxLayout() |
| 105 | self.model_path_input = QLineEdit() |
| 106 | self.model_path_input.setPlaceholderText('Path to model file (.pt)') |
| 107 | model_path_layout.addWidget(self.model_path_input) |
| 108 | |
| 109 | browse_btn = QPushButton('Browse') |
| 110 | browse_btn.clicked.connect(self.browse_model) |
| 111 | model_path_layout.addWidget(browse_btn) |
| 112 | |
| 113 | default_btn = QPushButton('Use Default') |
| 114 | default_btn.clicked.connect(self.use_default_model) |
| 115 | model_path_layout.addWidget(default_btn) |
| 116 | |
| 117 | model_layout.addLayout(model_path_layout) |
| 118 | |
| 119 | dataset_layout = QHBoxLayout() |
| 120 | dataset_label = QLabel('Dataset:') |
| 121 | self.dataset_input = QLineEdit('cifar10') |