()
| 8 | from neuralforge.config import Config |
| 9 | |
def _positive_int(text: str) -> int:
    """Parse a CLI value as an integer that must be >= 1 (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *text* is not an integer or is < 1, so
            argparse reports a clean usage error instead of a traceback.
    """
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _probability(text: str) -> float:
    """Parse a CLI value as a float in the closed interval [0, 1] (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *text* is not a float, or lies outside
            [0, 1] (NaN is rejected too, since every comparison with it is False).
    """
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"must be between 0 and 1, got {value}")
    return value


def main():
    """Command-line entry point for NeuralForge's evolutionary architecture search.

    Parses CLI options into a ``Config``, builds a ``SearchSpace`` plus synthetic
    train/validation datasets, runs ``EvolutionarySearch`` with a ``ProxyEvaluator``
    and prints the fitness, accuracy, parameter count and FLOPs of the best
    architecture found. Invalid numeric options terminate with an argparse usage
    error (exit status 2).
    """
    import sys  # only needed to send the unsupported-dataset warning to stderr

    parser = argparse.ArgumentParser(
        description='NeuralForge - Neural Architecture Search',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  neuralforge-nas --population 20 --generations 50
  neuralforge-nas --population 15 --generations 30 --mutation-rate 0.2
"""
    )

    parser.add_argument('--dataset', type=str, default='synthetic',
                        help="Dataset for evaluation (only 'synthetic' is currently implemented)")
    parser.add_argument('--population', type=_positive_int, default=15,
                        help='Population size')
    parser.add_argument('--generations', type=_positive_int, default=20,
                        help='Number of generations')
    # NOTE(review): assumes the mutation rate is a per-gene probability, hence [0, 1].
    parser.add_argument('--mutation-rate', type=_probability, default=0.15,
                        help='Mutation rate (probability in [0, 1])')
    parser.add_argument('--num-layers', type=_positive_int, default=15,
                        help="'num_layers' entry of the search-space config")
    parser.add_argument('--num-blocks', type=_positive_int, default=4,
                        help="'num_blocks' entry of the search-space config")
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Torch device (defaults to cuda when available, else cpu)')

    args = parser.parse_args()

    # Only synthetic data is wired up below. Warn instead of failing so existing
    # invocations keep running, but no longer ignore the option silently.
    if args.dataset != 'synthetic':
        print(f"Warning: --dataset {args.dataset!r} is not supported yet; "
              "falling back to the synthetic dataset.", file=sys.stderr)

    config = Config()
    config.device = args.device
    config.nas_enabled = True
    config.nas_population_size = args.population
    config.nas_generations = args.generations
    config.nas_mutation_rate = args.mutation_rate

    search_space = SearchSpace({
        'num_layers': args.num_layers,
        'num_blocks': args.num_blocks,
    })

    train_dataset = SyntheticDataset(num_samples=1000, num_classes=10)
    val_dataset = SyntheticDataset(num_samples=200, num_classes=10)

    # NOTE(review): these loaders are built but never passed to the evaluator or
    # the search below. They are kept so any DataLoaderBuilder side effects are
    # preserved. Confirm whether ProxyEvaluator/EvolutionarySearch should get them.
    loader_builder = DataLoaderBuilder(config)
    train_loader = loader_builder.build_train_loader(train_dataset)
    val_loader = loader_builder.build_val_loader(val_dataset)

    evaluator = ProxyEvaluator(device=config.device)

    evolution = EvolutionarySearch(
        search_space=search_space,
        evaluator=evaluator,
        population_size=config.nas_population_size,
        generations=config.nas_generations,
        mutation_rate=config.nas_mutation_rate,
    )

    print("Starting Neural Architecture Search...")
    best_architecture = evolution.search()

    print("\nBest Architecture Found:")
    print(f"Fitness: {best_architecture.fitness:.4f}")
    print(f"Accuracy: {best_architecture.accuracy:.2f}%")
    print(f"Parameters: {best_architecture.params:,}")
    print(f"FLOPs: {best_architecture.flops:,}")
no test coverage detected