Run `self.algo` iteratively (using the existing `self.trials` to produce new ones), update, and repeat. `block_until_done` means that the process blocks until no job in `trials` is still in the running or new state.
(self, N, block_until_done=True)
| 226 | self.serial_evaluate() |
| 227 | |
| 228 | def run(self, N, block_until_done=True): |
| 229 | """ |
| 230 | Run `self.algo` iteratively (use existing `self.trials` to produce the new |
| 231 | ones), update, and repeat |
| 232 | block_until_done means that the process blocks until ALL jobs in |
| 233 | trials are not in running or new state |
| 234 | |
| 235 | """ |
| 236 | trials = self.trials |
| 237 | algo = self.algo |
| 238 | n_queued = 0 |
| 239 | |
| 240 | def get_queue_len(): |
| 241 | return self.trials.count_by_state_unsynced(base.JOB_STATE_NEW) |
| 242 | |
| 243 | def get_n_done(): |
| 244 | return self.trials.count_by_state_unsynced(base.JOB_STATE_DONE) |
| 245 | |
| 246 | def get_n_unfinished(): |
| 247 | unfinished_states = [base.JOB_STATE_NEW, base.JOB_STATE_RUNNING] |
| 248 | return self.trials.count_by_state_unsynced(unfinished_states) |
| 249 | |
| 250 | stopped = False |
| 251 | initial_n_done = get_n_done() |
| 252 | with self.progress_callback( |
| 253 | initial=initial_n_done, total=self.max_evals |
| 254 | ) as progress_ctx: |
| 255 | |
| 256 | all_trials_complete = False |
| 257 | best_loss = float("inf") |
| 258 | while ( |
| 259 | # more run to Q || ( block_flag & trials not done ) |
| 260 | (n_queued < N or (block_until_done and not all_trials_complete)) |
| 261 | # no timeout || < current last time |
| 262 | and (self.timeout is None or (timer() - self.start_time) < self.timeout) |
| 263 | # no loss_threshold || < current best_loss |
| 264 | and (self.loss_threshold is None or best_loss >= self.loss_threshold) |
| 265 | ): |
| 266 | qlen = get_queue_len() |
| 267 | while ( |
| 268 | qlen < self.max_queue_len and n_queued < N and not self.is_cancelled |
| 269 | ): |
| 270 | n_to_enqueue = min(self.max_queue_len - qlen, N - n_queued) |
| 271 | # get ids for next trials to enqueue |
| 272 | new_ids = trials.new_trial_ids(n_to_enqueue) |
| 273 | self.trials.refresh() |
| 274 | # Based on existing trials and the domain, use `algo` to probe in |
| 275 | # new hp points. Save the results of those inspections into |
| 276 | # `new_trials`. This is the core of `run`, all the rest is just |
| 277 | # processes orchestration |
| 278 | new_trials = algo( |
| 279 | new_ids, self.domain, trials, self.rstate.integers(2 ** 31 - 1) |
| 280 | ) |
| 281 | assert len(new_ids) >= len(new_trials) |
| 282 | |
| 283 | if len(new_trials): |
| 284 | self.trials.insert_trial_docs(new_trials) |
| 285 | self.trials.refresh() |
no test coverage detected