Creates a multi-task trainer for the given task. Args: distribution_strategy: A distribution strategy. params: A MultiEvalExperimentConfig instance. task: A MultiTask instance. model: A MultiTaskBaseModel or tf_keras.Model instance. Returns: An Orbit trainer instance.
(
distribution_strategy: tf.distribute.Strategy,
params: configs.MultiEvalExperimentConfig,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel | tf_keras.Model,
)
| 151 | |
| 152 | |
def get_trainer(
    distribution_strategy: tf.distribute.Strategy,
    params: configs.MultiEvalExperimentConfig,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel | tf_keras.Model,
) -> orbit.StandardTrainer:
  """Creates a multi-task trainer for the given task.

  The trainer class is looked up in `TRAINERS` by
  `params.trainer.trainer_type`. The optimizer and the trainer are both
  constructed inside `distribution_strategy.scope()`.

  Args:
    distribution_strategy: A distribution strategy. The optimizer and the
      trainer are built inside its scope.
    params: A MultiEvalExperimentConfig instance. It is passed to the
      optimizer factory. `params.trainer.trainer_type` selects the trainer
      class. For the 'interleaving' trainer, `params.trainer.task_sampler`
      configures the task sampler.
    task: A MultiTask instance. When the trainer type is 'interleaving', its
      `task_weights` are used to build the task sampler.
    model: A MultiTaskBaseModel or tf_keras.Model instance.

  Returns:
    An Orbit trainer instance, created by the trainer class registered in
    `TRAINERS` under `params.trainer.trainer_type`.

  Raises:
    KeyError: If `params.trainer.trainer_type` is not a key of `TRAINERS`.
  """
  trainer_type = params.trainer.trainer_type
  # Resolve the trainer class before building the optimizer, so an unknown
  # trainer type fails fast without creating any optimizer state.
  try:
    trainer_cls = TRAINERS[trainer_type]
  except KeyError:
    raise KeyError(
        f'Unknown trainer_type {trainer_type!r}; expected one of '
        f'{list(TRAINERS)}.'
    ) from None

  with distribution_strategy.scope():
    kwargs = dict(
        multi_task=task,
        multi_task_model=model,
        # Built inside the scope so any variables it creates are managed
        # by the distribution strategy.
        optimizer=train_utils.create_optimizer(task, params),
    )
    if trainer_type == 'interleaving':
      # Only the 'interleaving' trainer is given a task sampler, built from
      # the sampler config and the per-task weights.
      kwargs['task_sampler'] = task_sampler.get_task_sampler(
          params.trainer.task_sampler, task.task_weights
      )
    return trainer_cls(**kwargs)
| 183 | |
| 184 | |
| 185 | TrainActionsFactoryType = Callable[ |
no test coverage detected