()
| 46 | ) |
| 47 | |
| 48 | def main(): |
| 49 | parser = argparse.ArgumentParser(description='NeuralForge Training') |
| 50 | parser.add_argument('--config', type=str, default=None, help='Path to config file') |
| 51 | parser.add_argument('--model', type=str, default='simple', choices=['simple', 'resnet18', 'efficientnet', 'vit']) |
| 52 | parser.add_argument('--batch-size', type=int, default=32) |
| 53 | parser.add_argument('--epochs', type=int, default=50) |
| 54 | parser.add_argument('--lr', type=float, default=0.001) |
| 55 | parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') |
| 56 | parser.add_argument('--num-samples', type=int, default=5000, help='Number of synthetic samples') |
| 57 | parser.add_argument('--num-classes', type=int, default=10) |
| 58 | parser.add_argument('--seed', type=int, default=42) |
| 59 | parser.add_argument('--dataset', type=str, default='synthetic', |
| 60 | choices=['synthetic', 'cifar10', 'cifar100', 'mnist', 'fashion_mnist', 'stl10', |
| 61 | 'tiny_imagenet', 'imagenet', 'food101', 'caltech256', 'oxford_pets'], |
| 62 | help='Dataset to use') |
| 63 | args = parser.parse_args() |
| 64 | |
| 65 | if args.config: |
| 66 | config = Config.load(args.config) |
| 67 | else: |
| 68 | config = Config() |
| 69 | config.batch_size = args.batch_size |
| 70 | config.epochs = args.epochs |
| 71 | config.learning_rate = args.lr |
| 72 | config.device = args.device |
| 73 | config.num_classes = args.num_classes |
| 74 | config.seed = args.seed |
| 75 | |
| 76 | set_seed(config.seed) |
| 77 | |
| 78 | logger = Logger(config.log_dir, "training") |
| 79 | logger.info("=" * 80) |
| 80 | logger.info("NeuralForge Training Framework") |
| 81 | logger.info("=" * 80) |
| 82 | logger.info(f"Configuration:\n{config}") |
| 83 | |
| 84 | if args.dataset == 'synthetic': |
| 85 | logger.info("Creating synthetic dataset...") |
| 86 | train_dataset = SyntheticDataset( |
| 87 | num_samples=args.num_samples, |
| 88 | num_classes=config.num_classes, |
| 89 | image_size=config.image_size, |
| 90 | channels=3 |
| 91 | ) |
| 92 | |
| 93 | val_dataset = SyntheticDataset( |
| 94 | num_samples=args.num_samples // 5, |
| 95 | num_classes=config.num_classes, |
| 96 | image_size=config.image_size, |
| 97 | channels=3 |
| 98 | ) |
| 99 | else: |
| 100 | logger.info(f"Downloading and loading {args.dataset} dataset...") |
| 101 | config.num_classes = get_num_classes(args.dataset) |
| 102 | |
| 103 | train_dataset = get_dataset(args.dataset, root=config.data_path, train=True, download=True) |
| 104 | val_dataset = get_dataset(args.dataset, root=config.data_path, train=False, download=True) |
| 105 |
no test coverage detected