Implements the training and evaluation APIs for the test model.
| 59 | |
| 60 | |
| 61 | class TestRunner(standard_runner.StandardTrainer, |
| 62 | standard_runner.StandardEvaluator): |
| 63 | """Implements the training and evaluation APIs for the test model.""" |
| 64 | |
def __init__(self, return_numpy=False):
  """Builds the model, optimizer, loss metrics and distributed datasets.

  Args:
    return_numpy: When truthy, the loop-end hooks return metric results
      converted with `.numpy()`; otherwise they return the raw results.
  """
  strategy = tf.distribute.get_strategy()
  self.strategy = strategy
  self.model = create_model()

  optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
  self.optimizer = optimizer
  # The runner's step counter is the optimizer's own iteration variable.
  self.global_step = optimizer.iterations

  self.train_loss = tf_keras.metrics.Mean(name="train_loss", dtype=tf.float32)
  self.eval_loss = tf_keras.metrics.Mean(name="eval_loss", dtype=tf.float32)
  self.return_numpy = return_numpy

  # Training and evaluation each get their own distributed dataset built
  # from the same `dataset_fn`.
  make_dataset = strategy.distribute_datasets_from_function
  train_ds = make_dataset(dataset_fn)
  eval_ds = make_dataset(dataset_fn)

  # Both base classes are initialised explicitly (cooperative super() is
  # not used here) so each receives its own dataset.
  standard_runner.StandardTrainer.__init__(self, train_ds)
  standard_runner.StandardEvaluator.__init__(self, eval_ds)
| 77 | |
def train_step(self, iterator):
  """Runs one distributed training step on the next element of `iterator`.

  On each replica: computes the mean MSE between targets and model outputs,
  differentiates it with respect to the model's trainable variables, applies
  the update with `self.optimizer`, and accumulates the loss into
  `self.train_loss`.

  Args:
    iterator: An iterator whose elements are `(inputs, targets)` pairs.
  """

  def _replicated_step(inputs):
    """Replicated training step."""
    inputs, targets = inputs
    with tf.GradientTape() as tape:
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
    # Only trainable variables are differentiated and updated. Non-trainable
    # variables have no gradient, so including them (as `model.variables`
    # would) feeds `None` gradients into `apply_gradients`.
    # NOTE(review): the per-replica mean loss is not divided by
    # `num_replicas_in_sync`. If this runs under a multi-replica strategy
    # that sums gradients across replicas, the effective step grows with
    # the replica count — confirm that is acceptable for this test model.
    variables = self.model.trainable_variables
    grads = tape.gradient(loss, variables)
    self.optimizer.apply_gradients(zip(grads, variables))
    self.train_loss.update_state(loss)

  self.strategy.run(_replicated_step, args=(next(iterator),))
| 91 | |
def train_loop_end(self):
  """Reports the accumulated training loss at the end of a training loop.

  Returns:
    A dict with a single key "loss" holding `self.train_loss.result()`,
    converted with `.numpy()` when `self.return_numpy` is truthy.
  """
  result = self.train_loss.result()
  if self.return_numpy:
    result = result.numpy()
  return {"loss": result}
| 97 | |
def build_eval_dataset(self):
  """Creates a new distributed evaluation dataset from `dataset_fn`.

  Returns:
    Whatever `self.strategy.distribute_datasets_from_function(dataset_fn)`
    returns.
  """
  make_dataset = self.strategy.distribute_datasets_from_function
  return make_dataset(dataset_fn)
| 100 | |
def eval_begin(self):
  """Clears the accumulated evaluation loss before an evaluation run."""
  # `reset_state()` is the current Keras metric API; `reset_states()` is only
  # a backwards-compatibility alias that forwards to it.
  self.eval_loss.reset_state()
| 103 | |
def eval_step(self, iterator):
  """Runs one distributed evaluation step on the next element of `iterator`.

  On each replica, the mean MSE between targets and model outputs is
  accumulated into `self.eval_loss`; no variables are updated.

  Args:
    iterator: An iterator whose elements are `(inputs, targets)` pairs.
  """

  def _eval_on_replica(batch):
    """Replicated evaluation step."""
    features, labels = batch
    predictions = self.model(features)
    batch_loss = tf.reduce_mean(tf_keras.losses.MSE(labels, predictions))
    self.eval_loss.update_state(batch_loss)

  run = self.strategy.run
  run(_eval_on_replica, args=(next(iterator),))
| 114 | |
| 115 | def eval_end(self): |
| 116 | eval_loss = self.eval_loss.result() |
| 117 | return { |
| 118 | "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss, |
no outgoing calls