()
| 235 | print(f"Error: {e}") |
| 236 | |
def main(argv=None):
    """Parse command-line options and run the requested NeuralForge test mode.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` keeps the original
            command-line behaviour; passing a list lets other code (and
            tests) drive the CLI without touching ``sys.argv``.

    Raises:
        SystemExit: raised by ``argparse`` for ``--help`` (code 0) or for
            invalid arguments (code 2), including a non-positive
            ``--samples`` value when ``--mode random`` is selected.
    """
    import argparse

    parser = argparse.ArgumentParser(description='Test trained NeuralForge model')

    # Default checkpoint: <this file's directory>/../models/best_model.pt
    default_model = os.path.join(os.path.dirname(__file__), '..', 'models', 'best_model.pt')
    parser.add_argument('--model', type=str, default=default_model, help='Path to model checkpoint')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'stl10',
                                 'tiny_imagenet', 'imagenet', 'food101', 'caltech256', 'oxford_pets'],
                        help='Dataset to test on')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--mode', type=str, default='interactive',
                        choices=['interactive', 'random', 'accuracy'],
                        help='Testing mode')
    parser.add_argument('--samples', type=int, default=10, help='Number of samples for random mode')
    args = parser.parse_args(argv)

    # Reject a meaningless sample count up front, before ModelTester is
    # constructed, instead of handing it to test_random_samples. Only checked
    # in random mode, the only mode that reads --samples.
    if args.mode == 'random' and args.samples < 1:
        parser.error('--samples must be a positive integer')

    tester = ModelTester(model_path=args.model, dataset=args.dataset, device=args.device)

    # --mode is restricted by `choices`, so exactly one branch runs.
    if args.mode == 'interactive':
        tester.interactive_mode()
    elif args.mode == 'random':
        tester.test_random_samples(args.samples)
    elif args.mode == 'accuracy':
        tester.test_class_accuracy()
| 263 | |
# Launch the CLI only when this file is executed directly, not on import.
if "__main__" == __name__:
    main()
no test coverage detected