()
| 9 | from src.python.neuralforge.config import Config |
| 10 | |
def main():
    """Run an evolutionary architecture search on synthetic data and report results.

    Sets the NAS hyper-parameters on a fresh ``Config``, runs
    ``EvolutionarySearch`` over a ``SearchSpace`` scored by a ``ProxyEvaluator``,
    prints the best architecture and the top-k architectures found, and finally
    instantiates the best architecture as a model to print its parameter count
    (summed over ``model.parameters()``).
    """
    # One source of truth for the label count: the synthetic datasets and the
    # model built at the end must agree on it.
    num_classes = 10
    # How many of the best architectures to list; used for both the header and
    # the query so the two cannot drift apart.
    top_k_count = 5

    config = Config()
    config.nas_enabled = True
    config.nas_population_size = 15
    config.nas_generations = 20
    config.nas_mutation_rate = 0.15

    search_config = {
        'num_layers': 15,
        'num_blocks': 4,
    }
    search_space = SearchSpace(search_config)

    train_dataset = SyntheticDataset(num_samples=1000, num_classes=num_classes)
    val_dataset = SyntheticDataset(num_samples=200, num_classes=num_classes)

    loader_builder = DataLoaderBuilder(config)
    # NOTE(review): these loaders are built but never consumed below -- the
    # evaluator is constructed with only a device. Either hand them to the
    # evaluator (if its API accepts loaders) or drop them; confirm intent.
    train_loader = loader_builder.build_train_loader(train_dataset)
    val_loader = loader_builder.build_val_loader(val_dataset)

    evaluator = ProxyEvaluator(device=config.device)

    evolution = EvolutionarySearch(
        search_space=search_space,
        evaluator=evaluator,
        population_size=config.nas_population_size,
        generations=config.nas_generations,
        mutation_rate=config.nas_mutation_rate,
    )

    print("Starting Neural Architecture Search...")
    best_architecture = evolution.search()

    print("\nBest Architecture Found:")
    print(f"Fitness: {best_architecture.fitness:.4f}")
    print(f"Accuracy: {best_architecture.accuracy:.2f}%")
    print(f"Parameters: {best_architecture.params:,}")
    print(f"FLOPs: {best_architecture.flops:,}")

    print(f"\nTop {top_k_count} Architectures:")
    top_k = evolution.get_top_k_architectures(k=top_k_count)
    for i, arch in enumerate(top_k, 1):
        print(f"{i}. Fitness: {arch.fitness:.4f}, Acc: {arch.accuracy:.2f}%, Params: {arch.params:,}")

    # Materialize the winning architecture and count its actual parameters.
    model = search_space.build_model(best_architecture, num_classes=num_classes)
    print(f"\nModel created with {sum(p.numel() for p in model.parameters()):,} parameters")


if __name__ == '__main__':
    main()
no test coverage detected