Runs train/eval configured by the experiment params. Args: distribution_strategy: A `tf.distribute.Strategy` instance. task: A MultiTask instance. model: A MultiTaskBaseModel instance. mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' or 'continuous_eval'.
(
*,
distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel | tf_keras.Model,
mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
trainer: base_trainer.MultiTaskBaseTrainer = None,
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
best_ckpt_exporter_creator: Optional[Any] = train_utils
.maybe_create_best_ckpt_exporter
)
| 37 | |
| 38 | |
| 39 | def run_experiment( |
| 40 | *, |
| 41 | distribution_strategy: tf.distribute.Strategy, |
| 42 | task: multitask.MultiTask, |
| 43 | model: base_model.MultiTaskBaseModel | tf_keras.Model, |
| 44 | mode: str, |
| 45 | params: configs.MultiTaskExperimentConfig, |
| 46 | model_dir: str, |
| 47 | run_post_eval: bool = False, |
| 48 | trainer: base_trainer.MultiTaskBaseTrainer = None, |
| 49 | eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None, |
| 50 | best_ckpt_exporter_creator: Optional[Any] = train_utils |
| 51 | .maybe_create_best_ckpt_exporter |
| 52 | ) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel, |
| 53 | Mapping[Any, Any]]]: |
| 54 | """Runs train/eval configured by the experiment params. |
| 55 | |
| 56 | Args: |
| 57 | distribution_strategy: A distribution distribution_strategy. |
| 58 | task: A MultiTaskTask instance. |
| 59 | model: A MultiTaskBaseModel instance. |
| 60 | mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' |
| 61 | or 'continuous_eval'. |
| 62 | params: ExperimentConfig instance. |
| 63 | model_dir: A 'str', a path to store model checkpoints and summaries. |
| 64 | run_post_eval: Whether to run post eval once after training, metrics logs |
| 65 | are returned. |
| 66 | trainer: (optional) A multi-task trainer to use. If none is provided, a |
| 67 | default one will be created based on `params`. |
| 68 | eval_summary_manager: Instance of the eval summary manager. If set, the |
| 69 | `eval_summary_dir` will be ignored. Otherwise the eval summary manager |
| 70 | will be created internally for TensorBoard summaries by default from the |
| 71 | `eval_summary_dir`. |
| 72 | best_ckpt_exporter_creator: A functor for creating best checkpoint exporter. |
| 73 | |
| 74 | Returns: |
| 75 | model: `base_model.MultiTaskBaseModel` instance. |
| 76 | """ |
| 77 | |
| 78 | is_training = 'train' in mode |
| 79 | is_eval = 'eval' in mode |
| 80 | with distribution_strategy.scope(): |
| 81 | if is_training and trainer is None: |
| 82 | trainer = get_trainer(distribution_strategy, params, task, model) |
| 83 | if is_eval: |
| 84 | eval_steps = task.task_eval_steps |
| 85 | evaluator = evaluator_lib.MultiTaskEvaluator( |
| 86 | eval_tasks=task.tasks.values(), |
| 87 | model=model, |
| 88 | eval_steps=eval_steps, |
| 89 | global_step=trainer.global_step if is_training else None, |
| 90 | checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir)) |
| 91 | else: |
| 92 | evaluator = None |
| 93 | |
| 94 | if trainer: |
| 95 | checkpoint = trainer.checkpoint |
| 96 | global_step = trainer.global_step |
nothing calls this directly
no test coverage detected