MCPcopy
hub / github.com/hojonathanho/diffusion / test_all

Function test_all

diffusion_tf/tpu_utils/classifier_metrics_numpy.py:203–222  ·  view source on GitHub ↗

Test against tfgan.eval.classifier_metrics

()

Source from the content-addressed store, hash-verified

201
202
def test_all():
    """
    Cross-check the NumPy metric functions against tfgan.eval.

    Draws random logits and two activation arrays from a RandomState seeded
    with 1234, evaluates the classifier score and the Frechet classifier
    distance with both the TF-GAN ops and the in-scope NumPy functions
    (classifier_score_from_logits and
    frechet_classifier_distance_from_activations), and asserts that each pair
    agrees under np.allclose. Prints 'all ok' once both comparisons pass.
    """

    # Function-scoped imports: these only execute when test_all() is called.
    import tensorflow.compat.v1 as tf
    import tensorflow_gan as tfgan

    rng = np.random.RandomState(1234)
    # The draw order below determines which values each array receives from
    # the seeded stream, so it is kept as logits -> randn -> rand.
    logits = rng.randn(64, 1008)
    acts_a = rng.randn(64, 2048)
    acts_b = rng.rand(256, 2048)

    with tf.Session() as sess:
        # Classifier score: TF-GAN reference value vs. NumPy value.
        score_op = tfgan.eval.classifier_score_from_logits(
            tf.convert_to_tensor(logits))
        expected_score = sess.run(score_op)
        actual_score = classifier_score_from_logits(logits)
        assert np.allclose(expected_score, actual_score)

        # Frechet classifier distance: TF-GAN reference value vs. NumPy value.
        distance_op = tfgan.eval.frechet_classifier_distance_from_activations(
            tf.convert_to_tensor(acts_a), tf.convert_to_tensor(acts_b))
        expected_distance = sess.run(distance_op)
        actual_distance = frechet_classifier_distance_from_activations(
            acts_a, acts_b)
        assert np.allclose(expected_distance, actual_distance)

    print('all ok')
223
224
225if __name__ == '__main__':

Callers 1

Calls 3

runMethod · 0.45

Tested by

no test coverage detected