()
| 14 | from neuralforge.models.resnet import ResNet18 |
| 15 | |
| 16 | def main(): |
| 17 | parser = argparse.ArgumentParser( |
| 18 | description='NeuralForge - Test trained models', |
| 19 | formatter_class=argparse.RawDescriptionHelpFormatter, |
| 20 | epilog=""" |
| 21 | Examples: |
| 22 | neuralforge-test --model models/best_model.pt --dataset cifar10 --mode random |
| 23 | neuralforge-test --dataset mnist --mode accuracy |
| 24 | neuralforge-test --dataset stl10 --image cat.jpg |
| 25 | """ |
| 26 | ) |
| 27 | |
| 28 | default_model = './models/best_model.pt' |
| 29 | parser.add_argument('--model', type=str, default=default_model, help='Path to model checkpoint') |
| 30 | parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset name') |
| 31 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') |
| 32 | parser.add_argument('--mode', type=str, default='random', choices=['random', 'accuracy', 'interactive']) |
| 33 | parser.add_argument('--samples', type=int, default=10, help='Number of samples for random mode') |
| 34 | parser.add_argument('--image', type=str, default=None, help='Path to image file') |
| 35 | |
| 36 | args = parser.parse_args() |
| 37 | |
| 38 | print("=" * 60) |
| 39 | print(" NeuralForge - Model Testing") |
| 40 | print("=" * 60) |
| 41 | print(f"Device: {args.device}") |
| 42 | |
| 43 | dataset_aliases = { |
| 44 | 'cifar-10': 'cifar10', 'stl-10': 'stl10', 'fashion-mnist': 'fashion_mnist', |
| 45 | 'tiny-imagenet': 'tiny_imagenet', 'food-101': 'food101', |
| 46 | } |
| 47 | dataset_name = dataset_aliases.get(args.dataset.lower(), args.dataset.lower()) |
| 48 | |
| 49 | num_classes = get_num_classes(dataset_name) |
| 50 | model = ResNet18(num_classes=num_classes) |
| 51 | model = model.to(args.device) |
| 52 | |
| 53 | if os.path.exists(args.model): |
| 54 | print(f"Loading model from: {args.model}") |
| 55 | checkpoint = torch.load(args.model, map_location=args.device, weights_only=False) |
| 56 | model.load_state_dict(checkpoint['model_state_dict']) |
| 57 | print(f"Model loaded from epoch {checkpoint.get('epoch', 'Unknown')}") |
| 58 | else: |
| 59 | print(f"Warning: No model found at {args.model}") |
| 60 | return |
| 61 | |
| 62 | model.eval() |
| 63 | |
| 64 | test_dataset = get_dataset(dataset_name, root='./data', train=False, download=True) |
| 65 | classes = getattr(test_dataset, 'classes', [str(i) for i in range(num_classes)]) |
| 66 | |
| 67 | print(f"Dataset: {dataset_name} ({len(test_dataset.dataset)} test samples)") |
| 68 | print("=" * 60) |
| 69 | |
| 70 | if args.image: |
| 71 | image = Image.open(args.image).convert('RGB') |
| 72 | transform = transforms.Compose([ |
| 73 | transforms.Resize(256), |
no test coverage detected