MCPcopy Index your code
hub / github.com/tensorflow/models / TestEvaluator

Class TestEvaluator

orbit/controller_test.py:122–156  ·  view source on GitHub ↗

Implements the evaluation API (a `StandardEvaluator`) for the test model.

Source retrieved from the content-addressed store (hash-verified).

120
121
class TestEvaluator(standard_runner.StandardEvaluator):
    """Implements the evaluation API for the test model.

    The model and a distributed evaluation dataset are built under whatever
    `tf.distribute` strategy is current at construction time. Each evaluation
    step computes a mean-squared-error loss per replica and averages it across
    replicas; `eval_end` reports the mean of those per-step losses as
    "eval_loss".
    """

    def __init__(self):
        self.strategy = tf.distribute.get_strategy()
        self.model = create_model()
        # The evaluation dataset is distributed with the same strategy the
        # model was created under.
        distributed_eval_data = self.strategy.distribute_datasets_from_function(
            dataset_fn)
        super().__init__(distributed_eval_data)

    def eval_reduce(self, state, output):
        """Appends one step's output to the accumulator list and returns it."""
        state.append(output)
        return state

    def eval_begin(self):
        """Returns a fresh, empty list used to collect per-step losses."""
        return []

    def eval_step(self, iterator):
        """Runs one distributed evaluation step.

        Args:
          iterator: Iterator over the distributed evaluation dataset; exactly
            one element is consumed per call.

        Returns:
          The per-replica mean losses reduced with `ReduceOp.MEAN`.
        """

        def _per_replica_loss(batch):
            """Returns the mean MSE between targets and model predictions."""
            # Each element is unpacked as an (inputs, targets) pair.
            features, labels = batch
            predictions = self.model(features)
            return tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))

        replica_losses = self.strategy.run(
            _per_replica_loss, args=(next(iterator),))
        return self.strategy.reduce(
            tf.distribute.ReduceOp.MEAN, replica_losses, axis=None)

    def eval_end(self, outputs):
        """Averages the collected per-step losses into the "eval_loss" metric."""
        metrics = {"eval_loss": tf.reduce_mean(outputs)}
        return metrics
157
158
159class TestEvaluatorNoOutput(runner.AbstractEvaluator):

Callers: 1

Calls

no outgoing calls

Tested by: 1