MCP · copy · Index your code
hub / github.com/tensorflow/models / TestTrainer

Class TestTrainer

orbit/standard_runner_test.py:38–58  ·  view source on GitHub ↗

A StandardTrainer subclass for tests.

Source from the content-addressed store, hash-verified

36
37
class TestTrainer(standard_runner.StandardTrainer):
  """Minimal `StandardTrainer` used by the runner tests.

  The trainer keeps a single counter, `global_step`:
    * `train_loop_begin` resets it to 0.
    * Each `train_step` has every replica call `assign_add(1)` on it.
    * `train_loop_end` returns its current value as a NumPy scalar.
  """

  def __init__(self, options=None):
    """Builds the trainer under the current distribution strategy.

    Args:
      options: Optional trainer options, forwarded unchanged to
        `StandardTrainer.__init__`.
    """
    self.strategy = tf.distribute.get_strategy()
    self.global_step = utils.create_global_step()
    # `dataset_fn` is a module-level helper defined elsewhere in this file.
    distributed_dataset = self.strategy.distribute_datasets_from_function(
        dataset_fn)
    super().__init__(train_dataset=distributed_dataset, options=options)

  def train_loop_begin(self):
    """Resets the step counter before each training loop."""
    self.global_step.assign(0)

  def train_step(self, iterator):
    """Consumes one element from `iterator` and bumps the step counter.

    Args:
      iterator: Distributed iterator over the training dataset.
    """

    def _increment_step(unused_inputs):
      # The input batch is ignored; the step only counts invocations.
      self.global_step.assign_add(1)

    next_batch = next(iterator)
    self.strategy.run(_increment_step, args=(next_batch,))

  def train_loop_end(self):
    """Returns the counter value accumulated during the loop."""
    return self.global_step.numpy()
59
60
61class TestEvaluator(standard_runner.StandardEvaluator):

Calls

No outgoing calls indexed.