Train until an evaluation stops improving. Works as a generator, with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`, where `info` is a dict and `is_best_checkpoint` is one of `True`, `False`, or `None`. `None` indicates that the iteration was not evaluated as a checkpoint. The evaluation is conducted by calling the `evaluate` callback.
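
The three-way `is_best_checkpoint` flag is easiest to see from the consumer's side. The sketch below does not call spaCy: `fake_training_steps` is a hypothetical stand-in that only mimics the tuple shape described above, with made-up scores, and `collect_best_steps` is an illustrative helper name.

```python
from typing import Any, Dict, Iterator, List, Optional, Tuple

Step = Tuple[List[str], Dict[str, Any], Optional[bool]]


def fake_training_steps() -> Iterator[Step]:
    """Hypothetical stand-in for the generator: same tuple shape, invented data."""
    # Assumption for illustration: evaluation runs on steps 2, 4 and 6.
    scheduled_scores = {2: 0.5, 4: 0.7, 6: 0.6}
    checkpoints: List[Tuple[float, int, int]] = []
    for step in range(1, 7):
        batch = [f"example-{step}"]
        info: Dict[str, Any] = {
            "epoch": 0,
            "step": step,
            "score": None,
            "checkpoints": checkpoints,
        }
        is_best_checkpoint: Optional[bool] = None  # None: not evaluated
        if step in scheduled_scores:
            score = scheduled_scores[step]
            checkpoints.append((score, step, 0))
            info["score"] = score
            # True only if this evaluation is the best one so far.
            is_best_checkpoint = score == max(s for s, _, _ in checkpoints)
        yield batch, info, is_best_checkpoint


def collect_best_steps(steps: Iterator[Step]) -> List[int]:
    """Return the steps at which a caller would save a 'best' checkpoint."""
    best_steps = []
    for _batch, info, is_best_checkpoint in steps:
        if is_best_checkpoint is None:
            continue  # no evaluation happened on this iteration
        if is_best_checkpoint:
            best_steps.append(info["step"])
    return best_steps


best_steps = collect_best_steps(fake_training_steps())
```

Testing `is None` before testing truthiness keeps "not evaluated" separate from "evaluated, but not the best", which a plain `if not is_best_checkpoint` would merge.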
train_while_improving(
    nlp: "Language",
    optimizer: Optimizer,
    train_data,
    evaluate,
    *,
    dropout: float,
    eval_frequency: int,
    accumulate_gradient: int,
    patience: int,
    max_steps: int,
    exclude: List[str],
    annotating_components: List[str],
    before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
)
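
The signature leaves `evaluate` untyped. The function's docstring describes it as a callback that takes no arguments and returns `(main_score, other_scores)`, where a higher `main_score` is better. A minimal sketch of a callback with that shape follows. The `make_evaluate` factory, the toy `dev_set`, and the accuracy metric are all assumptions made for illustration and are not part of the source.

```python
from typing import Any, Callable, Dict, List, Tuple


def make_evaluate(
    gold: List[Tuple[str, str]],
    predict: Callable[[str], str],
) -> Callable[[], Tuple[float, Dict[str, Any]]]:
    """Build a zero-argument callback that closes over the data it needs."""

    def evaluate() -> Tuple[float, Dict[str, Any]]:
        correct = sum(1 for text, label in gold if predict(text) == label)
        # Accuracy as the main score: a float where higher is better.
        accuracy = correct / len(gold) if gold else 0.0
        # other_scores can be any object; a dict is used here.
        return accuracy, {"correct": correct, "total": len(gold)}

    return evaluate


# Toy data and a trivial predictor that always answers "POS".
dev_set = [("good", "POS"), ("bad", "NEG"), ("fine", "POS"), ("awful", "NEG")]
evaluate = make_evaluate(dev_set, lambda text: "POS")
main_score, other_scores = evaluate()
```

Because the callback takes no arguments, anything it needs (development data, the model) has to be captured in a closure or bound beforehand, as the factory does here.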
```python
def train_while_improving(
    nlp: "Language",
    optimizer: Optimizer,
    train_data,
    evaluate,
    *,
    dropout: float,
    eval_frequency: int,
    accumulate_gradient: int,
    patience: int,
    max_steps: int,
    exclude: List[str],
    annotating_components: List[str],
    before_update: Optional[Callable[["Language", Dict[str, Any]], None]],
):
    """Train until an evaluation stops improving. Works as a generator,
    with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
    where info is a dict, and is_best_checkpoint is in [True, False, None] --
    None indicating that the iteration was not evaluated as a checkpoint.
    The evaluation is conducted by calling the evaluate callback.

    Positional arguments:
        nlp: The spaCy pipeline to evaluate.
        optimizer: The optimizer callable.
        train_data (Iterable[Batch]): A generator of batches, with the training
            data. Each batch should be a Sized[Tuple[Input, Annot]]. The training
            data iterable needs to take care of iterating over the epochs and
            shuffling.
        evaluate (Callable[[], Tuple[float, Any]]): A callback to perform evaluation.
            The callback should take no arguments and return a tuple
            `(main_score, other_scores)`. The main_score should be a float where
            higher is better. other_scores can be any object.

    Every iteration, the function yields out a tuple with:

    * batch: A list of Example objects.
    * info: A dict with various information about the last update (see below).
    * is_best_checkpoint: A value in None, False, True, indicating whether this
        was the best evaluation so far. You should use this to save the model
        checkpoints during training. If None, evaluation was not conducted on
        that iteration. False means evaluation was conducted, but a previous
        evaluation was better.

    The info dict provides the following information:

        epoch (int): How many passes over the data have been completed.
        step (int): How many steps have been completed.
        score (float): The main score from the last evaluation.
        other_scores: The other scores from the last evaluation.
        losses: The accumulated losses throughout training.
        checkpoints: A list of previous results, where each result is a
            (score, step, epoch) tuple.
    """
    if isinstance(dropout, float):
        dropouts = constant(dropout)
    else:
        dropouts = dropout
    results = []
```
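
The first statements of the function body normalize `dropout`: a plain float is wrapped by `constant(...)`, and anything else is used as given. The excerpt does not show where `constant` is defined. The sketch below assumes it returns an infinite iterator that always yields the same rate, and it substitutes a hypothetical stand-in built on `itertools.repeat`. `as_dropout_schedule` is likewise an illustrative name, not a function from the source.

```python
from itertools import repeat
from typing import Iterator, Union


def constant(rate: float) -> Iterator[float]:
    """Hypothetical stand-in: yield the same rate forever."""
    return repeat(rate)


def as_dropout_schedule(
    dropout: Union[float, Iterator[float]],
) -> Iterator[float]:
    """Mirror the isinstance branch in the excerpt."""
    if isinstance(dropout, float):
        return constant(dropout)
    # Anything that is not a float is assumed to already be a schedule.
    return dropout


fixed = as_dropout_schedule(0.1)
decaying = as_dropout_schedule(iter([0.3, 0.2, 0.1]))

first_fixed = [next(fixed) for _ in range(3)]
first_decaying = [next(decaying) for _ in range(3)]
```

One consequence of the `isinstance(dropout, float)` check is that an integer such as `0` is not a `float` in Python, so it takes the `else` branch and is passed through unwrapped.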