MCPcopy
hub / github.com/tensorflow/tfjs / testInference

Method testInference

tfjs-inference/python/inference_test.py:33–65  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

31class InferenceTest(tf.test.TestCase):
32
def testInference(self):
    """Run `inference.predict` once per backend and check the files read back.

    For each backend in ('cpu', 'wasm') a fresh temporary directory is passed
    to `inference.predict` as its fourth argument; afterwards `data.json`,
    `shape.json` and `dtype.json` are loaded from that directory and compared
    against the expected values, shape and dtype. The test also asserts that
    no `name.json` exists in the directory.

    The temporary directory is removed in a `finally` clause so that a failed
    assertion or an exception from `inference.predict` does not leak it.
    """
    # These inputs do not depend on the backend, so build them once.
    binary_path = os.path.join('../binaries', 'tfjs-inference-linux')
    model_path = os.path.join('../test_data', 'model.json')
    test_data_dir = '../test_data'

    for backend in ['cpu', 'wasm']:
        tmp_dir = tempfile.mkdtemp()
        try:
            inference.predict(
                binary_path, model_path, test_data_dir, tmp_dir,
                backend=backend)

            with open(os.path.join(tmp_dir, 'data.json'), 'rt') as f:
                ys_values = json.load(f)

            # The output is a list of tensor data in the form of dict.
            # Example output:
            # [{"0":0.7567615509033203,"1":-0.18349379301071167,"2":0.7567615509033203,"3":-0.18349379301071167}]
            ys_values = [list(y.values()) for y in ys_values]

            with open(os.path.join(tmp_dir, 'shape.json'), 'rt') as f:
                ys_shapes = json.load(f)

            with open(os.path.join(tmp_dir, 'dtype.json'), 'rt') as f:
                ys_dtypes = json.load(f)

            self.assertAllClose(ys_values[0], [
                0.7567615509033203, -0.18349379301071167, 0.7567615509033203,
                -0.18349379301071167
            ])
            self.assertAllEqual(ys_shapes[0], [2, 2])
            self.assertEqual(ys_dtypes[0], 'float32')
            self.assertFalse(os.path.exists(os.path.join(tmp_dir, 'name.json')))
        finally:
            # Cleanup tmp dir, even when predict or an assertion above fails.
            shutil.rmtree(tmp_dir)
66
67 # TODO(linazhao): Add a test model that outputs multiple tensors.
68 def testInferenceWithOutputNameFile(self):

Callers

nothing calls this directly

Calls 3

joinMethod · 0.80
predictMethod · 0.65
loadMethod · 0.45

Tested by

no test coverage detected