`test_all()` — Test against `tfgan.eval.classifier_metrics`.
| 201 | |
| 202 | |
def test_all():
    """
    Test against tfgan.eval.classifier_metrics

    Compares this module's ``classifier_score_from_logits`` and
    ``frechet_classifier_distance_from_activations`` with the reference
    implementations in ``tensorflow_gan`` on fixed-seed random inputs.
    Each comparison uses ``np.allclose`` with its default tolerances.
    Prints 'all ok' when both comparisons pass.

    Raises:
        AssertionError: if either result differs from the tfgan reference.
    """
    # Imported inside the function, so TensorFlow / tensorflow_gan are only
    # needed when this test is actually run.
    import tensorflow.compat.v1 as tf
    import tensorflow_gan as tfgan

    rand = np.random.RandomState(1234)
    logits = rand.randn(64, 1008)
    # Drawn in the same order as before (normal first, then uniform), so the
    # seeded values are unchanged. The two sets differ in size and
    # distribution.
    activations_a = rand.randn(64, 2048)
    activations_b = rand.rand(256, 2048)

    # Build the reference ops in an explicit graph. Under TF2 eager execution
    # is on by default, and a tf.compat.v1.Session cannot run eager tensors.
    # Inside Graph().as_default() the ops are graph ops again.
    with tf.Graph().as_default(), tf.Session() as sess:
        ref_score = sess.run(tfgan.eval.classifier_score_from_logits(
            tf.convert_to_tensor(logits)))
        ref_distance = sess.run(
            tfgan.eval.frechet_classifier_distance_from_activations(
                tf.convert_to_tensor(activations_a),
                tf.convert_to_tensor(activations_b)))

    assert np.allclose(
        ref_score,
        classifier_score_from_logits(logits)), 'classifier score mismatch'
    assert np.allclose(
        ref_distance,
        frechet_classifier_distance_from_activations(
            activations_a, activations_b)), 'frechet distance mismatch'
    print('all ok')
| 223 | |
| 224 | |
| 225 | if __name__ == '__main__': |
no test coverage detected