Implements the training and evaluation APIs for the test model.
| 120 | |
| 121 | |
class TestEvaluator(standard_runner.StandardEvaluator):
  """Evaluator for the test model returned by `create_model()`.

  Only the evaluation hooks of the runner API are implemented here:
  `eval_step` computes a cross-replica mean MSE loss for one batch,
  `eval_begin`/`eval_reduce` accumulate step outputs in a Python list, and
  `eval_end` averages them into a `{"eval_loss": ...}` dict.
  """

  def __init__(self):
    self.strategy = tf.distribute.get_strategy()
    self.model = create_model()
    # Distribute the eval dataset under the current strategy and hand it
    # directly to the base-class constructor.
    standard_runner.StandardEvaluator.__init__(
        self, self.strategy.distribute_datasets_from_function(dataset_fn))

  def eval_begin(self):
    """Returns a fresh, empty list used as the accumulation state."""
    return []

  def eval_step(self, iterator):
    """Runs one evaluation step across all replicas.

    Args:
      iterator: Iterator over the distributed eval dataset; exactly one
        element is consumed per call.

    Returns:
      The mean (over replicas) of the per-replica MSE losses for this step.
    """

    def _replica_loss(batch):
      """Computes the mean MSE loss on one replica's `(features, targets)`."""
      features, targets = batch
      predictions = self.model(features)
      return tf.reduce_mean(tf_keras.losses.MSE(targets, predictions))

    replica_losses = self.strategy.run(_replica_loss, args=(next(iterator),))
    return self.strategy.reduce(
        tf.distribute.ReduceOp.MEAN, replica_losses, axis=None)

  def eval_reduce(self, state, output):
    """Appends `output` to `state` in place and returns the same list."""
    state.append(output)
    return state

  def eval_end(self, outputs):
    """Averages `outputs` into the final metrics dict.

    `outputs` is presumably the list accumulated via `eval_begin` and
    `eval_reduce`; confirm against the base runner's evaluation loop.
    """
    mean_loss = tf.reduce_mean(outputs)
    return {"eval_loss": mean_loss}
| 157 | |
| 158 | |
| 159 | class TestEvaluatorNoOutput(runner.AbstractEvaluator): |
no outgoing calls