(self, distribution)
| 48 | strategy_combinations.one_device_strategy, |
| 49 | ],)) |
def test_ema_checkpointing(self, distribution):
  """EMACheckpointing writes averaged weights and rejects plain optimizers.

  The model's variables are shadowed by an EMA copy whose averages start at
  zero. The live value is then moved to 3 and the action is run. Finally the
  written checkpoint is restored to confirm it holds the averaged value (0)
  rather than the live one (3).
  """
  with distribution.scope():
    tmp_dir = self.create_tempdir()
    net = TestModel()
    ema_opt = optimization.ExponentialMovingAverage(
        tf_keras.optimizers.SGD(), trainable_weights_only=False)

    # Build the averaged (shadow) copy of every model variable; the averages
    # start out at zero.
    ema_opt.shadow_copy(net)
    ckpt = tf.train.Checkpoint(model=net)

    # Move the live value to 3 while the tracked average stays at 0.
    net.value.assign(3)
    self.assertEqual(net(0.), 3)

    action = actions.EMACheckpointing(tmp_dir, ema_opt, ckpt)
    action({})

    # The action is expected to write its checkpoints under this subfolder.
    ema_dir = os.path.join(tmp_dir, 'ema_checkpoints')
    self.assertNotEmpty(tf.io.gfile.glob(ema_dir))

    # Restoring the EMA checkpoint must yield the averaged value, not the
    # live value assigned above.
    ckpt.read(tf.train.latest_checkpoint(ema_dir))
    self.assertEqual(net(0.), 0)

    # An optimizer that is not an EMA wrapper must be refused.
    with self.assertRaisesRegex(ValueError,
                                'Optimizer has to be instance of.*'):
      actions.EMACheckpointing(tmp_dir, tf_keras.optimizers.SGD(), ckpt)
| 86 | |
| 87 | @combinations.generate( |
| 88 | combinations.combine( |
nothing calls this directly
no test coverage detected