MCP · Copy · Index your code
hub / github.com/tensorflow/models / TestRunner

Class TestRunner

orbit/controller_test.py:61–119  ·  view source on GitHub ↗

Implements the training and evaluation APIs for the test model.

Source retrieved from the content-addressed store (hash-verified).

59
60
61class TestRunner(standard_runner.StandardTrainer,
62 standard_runner.StandardEvaluator):
63 """Implements the training and evaluation APIs for the test model."""
64
def __init__(self, return_numpy=False):
  """Builds the model, optimizer, loss metrics and datasets for the runner.

  Args:
    return_numpy: If True, the loop-end hooks convert metric results with
      `.numpy()` before returning them; otherwise the raw results are
      returned.
  """
  strategy = tf.distribute.get_strategy()
  self.strategy = strategy
  self.model = create_model()
  self.optimizer = tf_keras.optimizers.RMSprop(learning_rate=0.1)
  # Reuse the optimizer's `iterations` counter as the runner's global step.
  self.global_step = self.optimizer.iterations
  # Created in the same order as before: "train_loss" first, then "eval_loss".
  self.train_loss, self.eval_loss = (
      tf_keras.metrics.Mean(metric_name, dtype=tf.float32)
      for metric_name in ("train_loss", "eval_loss"))
  self.return_numpy = return_numpy
  # Separate distributed datasets for training and evaluation, both built
  # from `dataset_fn` (training dataset first).
  train_dataset, eval_dataset = (
      strategy.distribute_datasets_from_function(dataset_fn)
      for _ in range(2))
  standard_runner.StandardTrainer.__init__(self, train_dataset)
  standard_runner.StandardEvaluator.__init__(self, eval_dataset)
77
def train_step(self, iterator):
  """Runs one distributed training step on the next element of `iterator`.

  Args:
    iterator: Iterator over the distributed training dataset; each element is
      unpacked as an `(inputs, targets)` pair.
  """

  def _replicated_step(inputs):
    """Replicated training step."""
    inputs, targets = inputs
    with tf.GradientTape() as tape:
      outputs = self.model(inputs)
      loss = tf.reduce_mean(tf_keras.losses.MSE(targets, outputs))
    # Differentiate only with respect to trainable variables. Non-trainable
    # variables (if the model has any) are not watched by the tape and would
    # only contribute `None` gradients to `apply_gradients`.
    trainable_vars = self.model.trainable_variables
    grads = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(grads, trainable_vars))
    self.train_loss.update_state(loss)

  self.strategy.run(_replicated_step, args=(next(iterator),))
91
def train_loop_end(self):
  """Reports the current mean training loss at the end of a training loop.

  Returns:
    A dict whose "loss" entry is `self.train_loss.result()`, converted with
    `.numpy()` when `self.return_numpy` is set.
  """
  result = self.train_loss.result()
  if self.return_numpy:
    result = result.numpy()
  return {"loss": result}
97
def build_eval_dataset(self):
  """Builds a distributed dataset from `dataset_fn` with this runner's strategy."""
  distribute = self.strategy.distribute_datasets_from_function
  return distribute(dataset_fn)
100
def eval_begin(self):
  """Clears the evaluation loss metric before a new evaluation run."""
  # `reset_state` is the current Keras metric API; `reset_states` is only a
  # deprecated backwards-compatibility alias that forwards to it.
  self.eval_loss.reset_state()
103
def eval_step(self, iterator):
  """Runs one distributed evaluation step on the next element of `iterator`.

  Each replica computes the mean MSE loss of its batch and records it in
  `self.eval_loss`. No gradients are computed in this step.

  Args:
    iterator: Iterator over the distributed evaluation dataset; each element
      is unpacked as an `(inputs, targets)` pair.
  """

  def _eval_on_replica(batch):
    """Computes and records the evaluation loss for one replica's batch."""
    features, labels = batch
    predictions = self.model(features)
    per_example_loss = tf_keras.losses.MSE(labels, predictions)
    self.eval_loss.update_state(tf.reduce_mean(per_example_loss))

  next_batch = next(iterator)
  self.strategy.run(_eval_on_replica, args=(next_batch,))
114
115 def eval_end(self):
116 eval_loss = self.eval_loss.result()
117 return {
118 "eval_loss": eval_loss.numpy() if self.return_numpy else eval_loss,

Calls

no outgoing calls