(self, architecture: Architecture, search_space: SearchSpace)
| 21 | self.quick_eval = quick_eval |
| 22 | |
def evaluate(
    self,
    architecture: Architecture,
    search_space: SearchSpace,
    *,
    lr: float = 0.001,
    param_weight: float = 0.1,
    flop_weight: float = 0.05,
    param_scale: float = 1e7,
    flop_scale: float = 1e9,
    failure_fitness: float = 0.0,
) -> Tuple[float, float]:
    """Build, train/evaluate and score one candidate architecture.

    The model is built with ``search_space.build_model(architecture)``, moved
    to ``self.device`` and passed, together with a ``nn.CrossEntropyLoss`` and
    an Adam optimizer, to ``self._quick_evaluate`` when ``self.quick_eval`` is
    truthy, otherwise to ``self._full_evaluate``. The accuracy returned by
    that helper is then penalised by model size::

        fitness = accuracy
                  - param_weight * params / param_scale
                  - flop_weight  * flops  / flop_scale

    where ``params`` and ``flops`` are read from
    ``search_space.estimate_complexity(architecture)``.

    Args:
        architecture: Candidate to evaluate.
        search_space: Object that builds the model and estimates its cost.
        lr: Adam learning rate.
        param_weight: Weight of the parameter-count penalty.
        flop_weight: Weight of the FLOP penalty.
        param_scale: Divisor normalising the parameter count.
        flop_scale: Divisor normalising the FLOP count.
        failure_fitness: Fitness reported when evaluation raises. Defaults to
            0.0 for backward compatibility. NOTE(review): a *valid*
            architecture can have negative penalised fitness, so 0.0 may rank
            a broken candidate above a working one; pass ``float("-inf")`` if
            the caller only compares/sorts fitness values.

    Returns:
        ``(fitness, accuracy)``. If any ``Exception`` is raised while building,
        training or scoring, it is logged with its traceback and
        ``(failure_fitness, 0.0)`` is returned instead (best-effort: a bad
        candidate must not abort the whole search).
    """
    import logging  # local import: the file's top-level imports are not visible here

    try:
        model = search_space.build_model(architecture).to(self.device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Only the selected helper is looked up, matching the original branch.
        run = self._quick_evaluate if self.quick_eval else self._full_evaluate
        accuracy = run(model, criterion, optimizer)

        complexity = search_space.estimate_complexity(architecture)
        param_penalty = complexity['params'] / param_scale
        flop_penalty = complexity['flops'] / flop_scale

        # Kept inside the try so a malformed accuracy/complexity value is
        # reported as a failed candidate, as before.
        fitness = accuracy - param_weight * param_penalty - flop_weight * flop_penalty
        return fitness, accuracy

    except Exception as e:
        # logger.exception records the traceback, which print() discarded.
        logging.getLogger(__name__).exception("Error evaluating architecture: %s", e)
        return failure_fitness, 0.0
| 50 | |
| 51 | def _quick_evaluate(self, model: nn.Module, criterion: nn.Module, optimizer: torch.optim.Optimizer) -> float: |
| 52 | model.train() |
no test coverage detected