Update the models in the pipeline. examples (Iterable[Example]): A batch of examples. _: Should not be set - serves to catch backwards-incompatible scripts. drop (float): The dropout rate. sgd (Optimizer): An optimizer. losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
(
self,
examples: Iterable[Example],
_: Optional[Any] = None,
*,
drop: float = 0.0,
sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = SimpleFrozenList(),
annotates: Iterable[str] = SimpleFrozenList(),
)
| 1143 | return doc |
| 1144 | |
| 1145 | def update( |
| 1146 | self, |
| 1147 | examples: Iterable[Example], |
| 1148 | _: Optional[Any] = None, |
| 1149 | *, |
| 1150 | drop: float = 0.0, |
| 1151 | sgd: Optional[Optimizer] = None, |
| 1152 | losses: Optional[Dict[str, float]] = None, |
| 1153 | component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, |
| 1154 | exclude: Iterable[str] = SimpleFrozenList(), |
| 1155 | annotates: Iterable[str] = SimpleFrozenList(), |
| 1156 | ): |
| 1157 | """Update the models in the pipeline. |
| 1158 | |
| 1159 | examples (Iterable[Example]): A batch of examples |
| 1160 | _: Should not be set - serves to catch backwards-incompatible scripts. |
| 1161 | drop (float): The dropout rate. |
| 1162 | sgd (Optimizer): An optimizer. |
| 1163 | losses (Dict[str, float]): Dictionary to update with the loss, keyed by |
| 1164 | component. |
| 1165 | component_cfg (Dict[str, Dict]): Config parameters for specific pipeline |
| 1166 | components, keyed by component name. |
| 1167 | exclude (Iterable[str]): Names of components that shouldn't be updated. |
| 1168 | annotates (Iterable[str]): Names of components that should set |
| 1169 | annotations on the predicted examples after updating. |
| 1170 | RETURNS (Dict[str, float]): The updated losses dictionary |
| 1171 | |
| 1172 | DOCS: https://spacy.io/api/language#update |
| 1173 | """ |
| 1174 | if _ is not None: |
| 1175 | raise ValueError(Errors.E989) |
| 1176 | if losses is None: |
| 1177 | losses = {} |
| 1178 | if isinstance(examples, list) and len(examples) == 0: |
| 1179 | return losses |
| 1180 | validate_examples(examples, "Language.update") |
| 1181 | examples = _copy_examples(examples) |
| 1182 | if sgd is None: |
| 1183 | if self._optimizer is None: |
| 1184 | self._optimizer = self.create_optimizer() |
| 1185 | sgd = self._optimizer |
| 1186 | if component_cfg is None: |
| 1187 | component_cfg = {} |
| 1188 | pipe_kwargs = {} |
| 1189 | for i, (name, proc) in enumerate(self.pipeline): |
| 1190 | component_cfg.setdefault(name, {}) |
| 1191 | pipe_kwargs[name] = deepcopy(component_cfg[name]) |
| 1192 | component_cfg[name].setdefault("drop", drop) |
| 1193 | pipe_kwargs[name].setdefault("batch_size", self.batch_size) |
| 1194 | for name, proc in self.pipeline: |
| 1195 | # ignore statements are used here because mypy ignores hasattr |
| 1196 | if name not in exclude and hasattr(proc, "update"): |
| 1197 | proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) # type: ignore |
| 1198 | if sgd not in (None, False): |
| 1199 | if ( |
| 1200 | name not in exclude |
| 1201 | and isinstance(proc, ty.TrainableComponent) |
| 1202 | and proc.is_trainable |