MCPcopy Index your code
hub / github.com/tensorflow/models / test_ema_checkpointing

Method test_ema_checkpointing

official/core/actions_test.py:50–85  ·  view source on GitHub ↗
(self, distribution)

Source from the content-addressed store, hash-verified

48 strategy_combinations.one_device_strategy,
49 ],))
def test_ema_checkpointing(self, distribution):
  """Checks that EMACheckpointing writes a checkpoint of the averaged weights.

  Zero-initialized average (shadow) weights are created for the model, the
  live model weight is then set to 3, and the action is run. The live weight
  must be unchanged afterwards. Restoring the written checkpoint must yield
  the averaged value (0) rather than the live one (3). Finally, the action
  must reject an optimizer that is not an `ExponentialMovingAverage`.

  Args:
    distribution: The `tf.distribute` strategy supplied by the
      `combinations.generate` decorator.
  """
  with distribution.scope():
    directory = self.create_tempdir()
    # The action writes its checkpoints under this sub-directory.
    ema_dir = os.path.join(directory, 'ema_checkpoints')
    model = TestModel()
    optimizer = tf_keras.optimizers.SGD()
    optimizer = optimization.ExponentialMovingAverage(
        optimizer, trainable_weights_only=False)

    # Creates average weights for the model variables. Average weights are
    # initialized to zero.
    optimizer.shadow_copy(model)
    checkpoint = tf.train.Checkpoint(model=model)

    # Changes model.value to 3; the average value is still 0.
    model.value.assign(3)

    # Checks model.value is 3.
    self.assertEqual(model(0.), 3)
    ema_action = actions.EMACheckpointing(directory, optimizer, checkpoint)

    ema_action({})
    self.assertNotEmpty(tf.io.gfile.glob(ema_dir))
    # Saving the averages must not leave the averaged weights swapped into
    # the live model; training should continue from the live value.
    self.assertEqual(model(0.), 3)

    checkpoint.read(tf.train.latest_checkpoint(ema_dir))

    # Checks model.value is 0 after restoring the averaged-weight checkpoint.
    self.assertEqual(model(0.), 0)

    # Raises an error for a normal optimizer.
    with self.assertRaisesRegex(ValueError,
                                'Optimizer has to be instance of.*'):
      _ = actions.EMACheckpointing(directory, tf_keras.optimizers.SGD(),
                                   checkpoint)
86
87 @combinations.generate(
88 combinations.combine(

Callers

nothing calls this directly

Calls 5

shadow_copy · Method · 0.95
TestModel · Class · 0.70
assign · Method · 0.45
join · Method · 0.45
read · Method · 0.45

Tested by

no test coverage detected