MCPcopy Index your code
hub / github.com/geekcomputers/Python / main

Function main

ML/src/python/neuralforge/cli/train.py:46–205  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

44 )
45
46def main():
47 parser = argparse.ArgumentParser(
48 description='NeuralForge - Train neural networks with CUDA acceleration',
49 formatter_class=argparse.RawDescriptionHelpFormatter,
50 epilog="""
51Examples:
52 neuralforge --dataset cifar10 --epochs 50
53 neuralforge --dataset mnist --model simple --batch-size 64
54 neuralforge --dataset stl10 --model resnet18 --epochs 100 --lr 0.001
55 neuralforge --dataset tiny_imagenet --batch-size 128 --epochs 200
56 """
57 )
58
59 parser.add_argument('--config', type=str, default=None, help='Path to config file')
60 parser.add_argument('--model', type=str, default='simple',
61 choices=['simple', 'resnet18', 'efficientnet', 'vit'],
62 help='Model architecture')
63 parser.add_argument('--dataset', type=str, default='synthetic',
64 help='Dataset (cifar10, mnist, stl10, tiny_imagenet, etc.)')
65 parser.add_argument('--batch-size', type=int, default=32, help='Batch size')
66 parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
67 parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
68 parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
69 help='Device (cuda/cpu)')
70 parser.add_argument('--num-samples', type=int, default=5000, help='Number of synthetic samples')
71 parser.add_argument('--num-classes', type=int, default=10, help='Number of classes (for synthetic)')
72 parser.add_argument('--seed', type=int, default=42, help='Random seed')
73 parser.add_argument('--optimizer', type=str, default='adamw',
74 choices=['adamw', 'adam', 'sgd'],
75 help='Optimizer')
76 parser.add_argument('--scheduler', type=str, default='cosine',
77 choices=['cosine', 'onecycle', 'none'],
78 help='Learning rate scheduler')
79
80 args = parser.parse_args()
81
82 if args.config:
83 config = Config.load(args.config)
84 else:
85 config = Config()
86 config.batch_size = args.batch_size
87 config.epochs = args.epochs
88 config.learning_rate = args.lr
89 config.device = args.device
90 config.num_classes = args.num_classes
91 config.seed = args.seed
92 config.optimizer = args.optimizer
93 config.scheduler = args.scheduler
94
95 # Set paths relative to current working directory (not package directory)
96 import os
97 cwd = os.getcwd()
98 config.model_dir = os.path.join(cwd, "models")
99 config.log_dir = os.path.join(cwd, "logs")
100 config.data_path = os.path.join(cwd, "data")
101
102 set_seed(config.seed)
103

Callers 1

train.py — File · 0.70

Calls 15

info — Method · 0.95
build_train_loader — Method · 0.95
build_val_loader — Method · 0.95
log_model_summary — Method · 0.95
train — Method · 0.95
Config — Class · 0.90
Logger — Class · 0.90
SyntheticDataset — Class · 0.90
get_num_classes — Function · 0.90
get_dataset — Function · 0.90
DataLoaderBuilder — Class · 0.90
ResNet18 — Function · 0.90

Tested by

no test coverage detected