MCPcopy Index your code
hub / github.com/tensorflow/models / run_benchmark

Function run_benchmark

official/benchmark/benchmark_lib.py:45–153  ·  view source on GitHub ↗

Runs benchmark for a specific experiment. Args: execution_mode: A 'str', specifying the mode. Can be 'accuracy', 'performance', or 'tflite_accuracy'. params: ExperimentConfig instance. model_dir: A 'str', a path to store model checkpoints and summaries. distribution_strategy: A tf.distribute.Strategy to use. If specified, it will be used instead of inferring the strategy from params.

(
    execution_mode: str,
    params: config_definitions.ExperimentConfig,
    model_dir: str,
    distribution_strategy: tf.distribute.Strategy = None
)

Source from the content-addressed store, hash-verified

43
44
45def run_benchmark(
46 execution_mode: str,
47 params: config_definitions.ExperimentConfig,
48 model_dir: str,
49 distribution_strategy: tf.distribute.Strategy = None
50) -> Mapping[str, Any]:
51 """Runs benchmark for a specific experiment.
52
53 Args:
54 execution_mode: A 'str', specifying the mode. Can be 'accuracy',
55 'performance', or 'tflite_accuracy'.
56 params: ExperimentConfig instance.
57 model_dir: A 'str', a path to store model checkpoints and summaries.
58 distribution_strategy: A tf.distribute.Strategy to use. If specified,
59 it will be used instead of inferring the strategy from params.
60
61 Returns:
62 benchmark_data: returns benchmark data in dict format.
63
64 Raises:
65 NotImplementedError: If try to use unsupported setup.
66 """
67
68 # For GPU runs, allow option to set thread mode
69 if params.runtime.gpu_thread_mode:
70 os.environ['TF_GPU_THREAD_MODE'] = params.runtime.gpu_thread_mode
71 logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])
72
73 # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
74 # can have significant impact on model speeds by utilizing float16 in case of
75 # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
76 # dtype is float16
77 if params.runtime.mixed_precision_dtype:
78 performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
79
80 strategy = distribution_strategy or distribute_utils.get_distribution_strategy(
81 distribution_strategy=params.runtime.distribution_strategy,
82 all_reduce_alg=params.runtime.all_reduce_alg,
83 num_gpus=params.runtime.num_gpus,
84 tpu_address=params.runtime.tpu)
85
86 with strategy.scope():
87 task = task_factory.get_task(params.task, logging_dir=model_dir)
88 trainer = train_utils.create_trainer(
89 params,
90 task,
91 train=True,
92 evaluate=(execution_mode == 'accuracy'))
93 # Initialize the model if possible, e.g., from a pre-trained checkpoint.
94 trainer.initialize()
95
96 steps_per_loop = params.trainer.steps_per_loop if (
97 execution_mode in ['accuracy', 'tflite_accuracy']) else 100
98
99 train_output_recorder = _OutputRecorderAction()
100 controller = orbit.Controller(
101 strategy=strategy,
102 trainer=trainer,

Callers

nothing calls this directly

Calls 7

train · Method · 0.95
info · Method · 0.80
update · Method · 0.80
initialize · Method · 0.45
evaluate · Method · 0.45
train_and_evaluate · Method · 0.45

Tested by

no test coverage detected