hub / github.com/tensorflow/models / get_trainer

Function get_trainer

official/modeling/multitask/train_lib.py:153–182

Creates a multi-task trainer for the given task.

Args:
    distribution_strategy: A distribution strategy.
    params: ExperimentConfig instance.
    task: A MultiTaskTask instance.
    model: A MultiTaskBaseModel instance.

Returns:
    An Orbit trainer instance.

(
    distribution_strategy: tf.distribute.Strategy,
    params: configs.MultiEvalExperimentConfig,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel | tf_keras.Model,
)

Source from the content-addressed store, hash-verified

def get_trainer(
    distribution_strategy: tf.distribute.Strategy,
    params: configs.MultiEvalExperimentConfig,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel | tf_keras.Model,
) -> orbit.StandardTrainer:
  """Creates a multi-task trainer for the given task.

  Args:
    distribution_strategy: A distribution strategy.
    params: ExperimentConfig instance.
    task: A MultiTaskTask instance.
    model: A MultiTaskBaseModel instance.

  Returns:
    An Orbit trainer instance.
  """
  with distribution_strategy.scope():
    kwargs = dict(
        multi_task=task,
        multi_task_model=model,
        optimizer=train_utils.create_optimizer(task, params),
    )
    if params.trainer.trainer_type == 'interleaving':
      kwargs.update(
          task_sampler=task_sampler.get_task_sampler(
              params.trainer.task_sampler, task.task_weights
          )
      )
    return TRAINERS[params.trainer.trainer_type](**kwargs)
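
The listing above builds a shared set of keyword arguments, adds a `task_sampler` only for the `'interleaving'` trainer type, and then looks up the trainer class in the `TRAINERS` registry. The following is a minimal, self-contained sketch of that dispatch pattern only. The classes `JointTrainer` and `InterleavingTrainer`, the `'joint'` key, the `build_trainer` helper, and its `make_sampler` parameter are all hypothetical stand-ins for illustration; they are not the actual Model Garden classes or registry contents, and the sketch omits the distribution strategy scope and optimizer creation.

```python
# Hypothetical stand-ins for the real trainer classes; names are assumptions.
class JointTrainer:
    def __init__(self, multi_task, multi_task_model, optimizer):
        self.multi_task = multi_task
        self.multi_task_model = multi_task_model
        self.optimizer = optimizer


class InterleavingTrainer(JointTrainer):
    # Takes one extra argument, mirroring the extra task_sampler kwarg above.
    def __init__(self, multi_task, multi_task_model, optimizer, task_sampler):
        super().__init__(multi_task, multi_task_model, optimizer)
        self.task_sampler = task_sampler


# Registry keyed by trainer type. Only 'interleaving' appears in the listing;
# the 'joint' key is an assumption made for this sketch.
TRAINERS = {
    'joint': JointTrainer,
    'interleaving': InterleavingTrainer,
}


def build_trainer(trainer_type, task, model, optimizer, make_sampler):
    """Builds a trainer by registry lookup (hypothetical helper)."""
    # Arguments that every trainer type accepts.
    kwargs = dict(
        multi_task=task,
        multi_task_model=model,
        optimizer=optimizer,
    )
    # Only the interleaving trainer needs a sampler, so create it lazily.
    if trainer_type == 'interleaving':
        kwargs.update(task_sampler=make_sampler())
    # An unregistered trainer_type raises KeyError here.
    return TRAINERS[trainer_type](**kwargs)


if __name__ == '__main__':
    trainer = build_trainer(
        'interleaving', 'task', 'model', 'optimizer', lambda: 'sampler'
    )
    print(type(trainer).__name__)
```

One consequence of this design, visible in the sketch, is that an unregistered trainer type fails with a plain `KeyError` at lookup time, and each registered class must accept exactly the keyword arguments assembled for its type.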

Callers 1

run_experiment · Function · 0.85

Calls 2

update · Method · 0.80
create_optimizer · Method · 0.45

Tested by

no test coverage detected