()
| 44 | ) |
| 45 | |
| 46 | def main(): |
| 47 | parser = argparse.ArgumentParser( |
| 48 | description='NeuralForge - Train neural networks with CUDA acceleration', |
| 49 | formatter_class=argparse.RawDescriptionHelpFormatter, |
| 50 | epilog=""" |
| 51 | Examples: |
| 52 | neuralforge --dataset cifar10 --epochs 50 |
| 53 | neuralforge --dataset mnist --model simple --batch-size 64 |
| 54 | neuralforge --dataset stl10 --model resnet18 --epochs 100 --lr 0.001 |
| 55 | neuralforge --dataset tiny_imagenet --batch-size 128 --epochs 200 |
| 56 | """ |
| 57 | ) |
| 58 | |
| 59 | parser.add_argument('--config', type=str, default=None, help='Path to config file') |
| 60 | parser.add_argument('--model', type=str, default='simple', |
| 61 | choices=['simple', 'resnet18', 'efficientnet', 'vit'], |
| 62 | help='Model architecture') |
| 63 | parser.add_argument('--dataset', type=str, default='synthetic', |
| 64 | help='Dataset (cifar10, mnist, stl10, tiny_imagenet, etc.)') |
| 65 | parser.add_argument('--batch-size', type=int, default=32, help='Batch size') |
| 66 | parser.add_argument('--epochs', type=int, default=50, help='Number of epochs') |
| 67 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') |
| 68 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', |
| 69 | help='Device (cuda/cpu)') |
| 70 | parser.add_argument('--num-samples', type=int, default=5000, help='Number of synthetic samples') |
| 71 | parser.add_argument('--num-classes', type=int, default=10, help='Number of classes (for synthetic)') |
| 72 | parser.add_argument('--seed', type=int, default=42, help='Random seed') |
| 73 | parser.add_argument('--optimizer', type=str, default='adamw', |
| 74 | choices=['adamw', 'adam', 'sgd'], |
| 75 | help='Optimizer') |
| 76 | parser.add_argument('--scheduler', type=str, default='cosine', |
| 77 | choices=['cosine', 'onecycle', 'none'], |
| 78 | help='Learning rate scheduler') |
| 79 | |
| 80 | args = parser.parse_args() |
| 81 | |
| 82 | if args.config: |
| 83 | config = Config.load(args.config) |
| 84 | else: |
| 85 | config = Config() |
| 86 | config.batch_size = args.batch_size |
| 87 | config.epochs = args.epochs |
| 88 | config.learning_rate = args.lr |
| 89 | config.device = args.device |
| 90 | config.num_classes = args.num_classes |
| 91 | config.seed = args.seed |
| 92 | config.optimizer = args.optimizer |
| 93 | config.scheduler = args.scheduler |
| 94 | |
| 95 | # Set paths relative to current working directory (not package directory) |
| 96 | import os |
| 97 | cwd = os.getcwd() |
| 98 | config.model_dir = os.path.join(cwd, "models") |
| 99 | config.log_dir = os.path.join(cwd, "logs") |
| 100 | config.data_path = os.path.join(cwd, "data") |
| 101 | |
| 102 | set_seed(config.seed) |
| 103 |
no test coverage detected