MCPcopy
hub / github.com/geekcomputers/Python / main

Function main

ML/src/python/neuralforge/cli/test.py:16–133  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

14from neuralforge.models.resnet import ResNet18
15
16def main():
17 parser = argparse.ArgumentParser(
18 description='NeuralForge - Test trained models',
19 formatter_class=argparse.RawDescriptionHelpFormatter,
20 epilog="""
21Examples:
22 neuralforge-test --model models/best_model.pt --dataset cifar10 --mode random
23 neuralforge-test --dataset mnist --mode accuracy
24 neuralforge-test --dataset stl10 --image cat.jpg
25 """
26 )
27
28 default_model = './models/best_model.pt'
29 parser.add_argument('--model', type=str, default=default_model, help='Path to model checkpoint')
30 parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset name')
31 parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
32 parser.add_argument('--mode', type=str, default='random', choices=['random', 'accuracy', 'interactive'])
33 parser.add_argument('--samples', type=int, default=10, help='Number of samples for random mode')
34 parser.add_argument('--image', type=str, default=None, help='Path to image file')
35
36 args = parser.parse_args()
37
38 print("=" * 60)
39 print(" NeuralForge - Model Testing")
40 print("=" * 60)
41 print(f"Device: {args.device}")
42
43 dataset_aliases = {
44 'cifar-10': 'cifar10', 'stl-10': 'stl10', 'fashion-mnist': 'fashion_mnist',
45 'tiny-imagenet': 'tiny_imagenet', 'food-101': 'food101',
46 }
47 dataset_name = dataset_aliases.get(args.dataset.lower(), args.dataset.lower())
48
49 num_classes = get_num_classes(dataset_name)
50 model = ResNet18(num_classes=num_classes)
51 model = model.to(args.device)
52
53 if os.path.exists(args.model):
54 print(f"Loading model from: {args.model}")
55 checkpoint = torch.load(args.model, map_location=args.device, weights_only=False)
56 model.load_state_dict(checkpoint['model_state_dict'])
57 print(f"Model loaded from epoch {checkpoint.get('epoch', 'Unknown')}")
58 else:
59 print(f"Warning: No model found at {args.model}")
60 return
61
62 model.eval()
63
64 test_dataset = get_dataset(dataset_name, root='./data', train=False, download=True)
65 classes = getattr(test_dataset, 'classes', [str(i) for i in range(num_classes)])
66
67 print(f"Dataset: {dataset_name} ({len(test_dataset.dataset)} test samples)")
68 print("=" * 60)
69
70 if args.image:
71 image = Image.open(args.image).convert('RGB')
72 transform = transforms.Compose([
73 transforms.Resize(256),

Callers 1

test.pyFile · 0.70

Calls 6

get_num_classesFunction · 0.90
ResNet18Function · 0.90
get_datasetFunction · 0.90
convertMethod · 0.80
getMethod · 0.45
loadMethod · 0.45

Tested by

no test coverage detected