A StandardTrainer subclass for tests.
| 36 | |
| 37 | |
class TestTrainer(standard_runner.StandardTrainer):
    """Minimal StandardTrainer used by the tests.

    The trainer keeps a single step counter, `global_step`. It is reset to
    zero when a train loop begins. Each `train_step` runs a replica function
    via `strategy.run` that adds 1 to it. Its value is returned as NumPy when
    the loop ends.
    """

    def __init__(self, options=None):
        """Sets up the strategy, the step counter and the train dataset.

        Args:
          options: Optional trainer options, forwarded unchanged to
            `StandardTrainer.__init__`.
        """
        self.strategy = tf.distribute.get_strategy()
        self.global_step = utils.create_global_step()
        # `dataset_fn` is a module-level name defined outside this class.
        train_ds = self.strategy.distribute_datasets_from_function(dataset_fn)
        super().__init__(train_dataset=train_ds, options=options)

    def train_loop_begin(self):
        """Zeroes the step counter before each training loop."""
        self.global_step.assign(0)

    def train_step(self, iterator):
        """Takes one element from `iterator` and runs one replica step on it."""
        batch = next(iterator)

        def _bump_step(unused_batch):
            # The batch contents are ignored; only the counter is touched.
            self.global_step.assign_add(1)

        self.strategy.run(_bump_step, args=(batch,))

    def train_loop_end(self):
        """Returns the current step counter value via `.numpy()`."""
        step_value = self.global_step.numpy()
        return step_value
| 59 | |
| 60 | |
| 61 | class TestEvaluator(standard_runner.StandardEvaluator): |
no outgoing calls