MCPcopy
hub / github.com/ibab/tensorflow-wavenet / testCompareSimpleFast

Method testCompareSimpleFast

test/test_generation.py:50–72  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

48 self.assertTrue(np.all((proba >= 0) & (proba <= (128 - 1))))
49
def testCompareSimpleFast(self):
    """Check that incremental generation agrees with the simple generator.

    A fixed random sequence of 1000 integer samples in ``[0, 128)`` is fed
    one sample at a time through ``predict_proba_incremental``. The output
    for the final sample must be close to what ``predict_proba`` returns
    when it is given the whole sequence at once.
    """
    waveform = tf.placeholder(tf.int32)
    # A local seeded generator produces exactly the same values as
    # ``np.random.seed(0); np.random.randint(...)``, but it does not
    # reseed the process-global NumPy RNG. Reseeding would be a side effect
    # visible to any code that runs after this test.
    rng = np.random.RandomState(0)
    data = rng.randint(128, size=1000)
    proba = self.net.predict_proba(waveform)
    proba_fast = self.net.predict_proba_incremental(waveform)
    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(self.net.init_ops)
        # Prime the incremental generator with every sample except the
        # last one. push_ops is run with each step, presumably to advance
        # the generator's internal state (confirm in the model). The
        # per-step outputs are never used, so they are not kept.
        for x in data[:-1]:
            sess.run([proba_fast, self.net.push_ops],
                     feed_dict={waveform: x})

        # Get the output for the last sample from the incremental
        # generator. Only this output is compared, so push_ops is not run.
        proba_fast_ = sess.run(proba_fast, feed_dict={waveform: data[-1]})
        # Get the output of the simple generator over the full sequence.
        proba_ = sess.run(proba, feed_dict={waveform: data})
        self.assertAllClose(proba_, proba_fast_)
73
74
75class TestGenerationBiases(TestGeneration):

Callers

nothing calls this directly

Calls 2

predict_proba Method · 0.80

Tested by

no test coverage detected