MCPcopy
hub / github.com/tensorflow/models / run_experiment

Function run_experiment

official/modeling/multitask/train_lib.py:39–150  ·  view source on GitHub ↗

Runs train/eval configured by the experiment params. Args: distribution_strategy: A distribution strategy. task: A MultiTask instance. model: A MultiTaskBaseModel instance. mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' or 'continuous_eval'.

(
    *,
    distribution_strategy: tf.distribute.Strategy,
    task: multitask.MultiTask,
    model: base_model.MultiTaskBaseModel | tf_keras.Model,
    mode: str,
    params: configs.MultiTaskExperimentConfig,
    model_dir: str,
    run_post_eval: bool = False,
    trainer: base_trainer.MultiTaskBaseTrainer = None,
    eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
    best_ckpt_exporter_creator: Optional[Any] = train_utils
    .maybe_create_best_ckpt_exporter
)

Source from the content-addressed store, hash-verified

37
38
39def run_experiment(
40 *,
41 distribution_strategy: tf.distribute.Strategy,
42 task: multitask.MultiTask,
43 model: base_model.MultiTaskBaseModel | tf_keras.Model,
44 mode: str,
45 params: configs.MultiTaskExperimentConfig,
46 model_dir: str,
47 run_post_eval: bool = False,
48 trainer: base_trainer.MultiTaskBaseTrainer = None,
49 eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
50 best_ckpt_exporter_creator: Optional[Any] = train_utils
51 .maybe_create_best_ckpt_exporter
52) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
53 Mapping[Any, Any]]]:
54 """Runs train/eval configured by the experiment params.
55
56 Args:
57 distribution_strategy: A distribution distribution_strategy.
58 task: A MultiTaskTask instance.
59 model: A MultiTaskBaseModel instance.
60 mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
61 or 'continuous_eval'.
62 params: ExperimentConfig instance.
63 model_dir: A 'str', a path to store model checkpoints and summaries.
64 run_post_eval: Whether to run post eval once after training, metrics logs
65 are returned.
66 trainer: (optional) A multi-task trainer to use. If none is provided, a
67 default one will be created based on `params`.
68 eval_summary_manager: Instance of the eval summary manager. If set, the
69 `eval_summary_dir` will be ignored. Otherwise the eval summary manager
70 will be created internally for TensorBoard summaries by default from the
71 `eval_summary_dir`.
72 best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
73
74 Returns:
75 model: `base_model.MultiTaskBaseModel` instance.
76 """
77
78 is_training = 'train' in mode
79 is_eval = 'eval' in mode
80 with distribution_strategy.scope():
81 if is_training and trainer is None:
82 trainer = get_trainer(distribution_strategy, params, task, model)
83 if is_eval:
84 eval_steps = task.task_eval_steps
85 evaluator = evaluator_lib.MultiTaskEvaluator(
86 eval_tasks=task.tasks.values(),
87 model=model,
88 eval_steps=eval_steps,
89 global_step=trainer.global_step if is_training else None,
90 checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
91 else:
92 evaluator = None
93
94 if trainer:
95 checkpoint = trainer.checkpoint
96 global_step = trainer.global_step

Callers

nothing calls this directly

Calls 8

train (method) · 0.95
train_and_evaluate (method) · 0.95
evaluate (method) · 0.95
evaluate_continuously (method) · 0.95
evaluate (method) · 0.95
get_trainer (function) · 0.85
info (method) · 0.80
join (method) · 0.45

Tested by

no test coverage detected