Repository: github.com/tensorlayer/TensorLayer

Function test

tensorlayer/utils.py:170–217

Test a given non-time-series network on the given test data and metric, and return the value computed by the acc metric.

(network, acc, X_test, y_test, batch_size, cost=None)
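In the listing, test calls the metric as acc(y_pred, y_test) and the loss as cost(y_pred, y_test), so both are plain callables that take predictions first and targets second. The sketch below reuses the docstring's acc function and pairs it with cross_entropy_np, a hypothetical NumPy stand-in for a loss such as tl.cost.cross_entropy (assumed here to be mean softmax cross-entropy over integer labels). It illustrates only the call contract and is not TensorLayer's implementation.

```python
import numpy as np

def acc(_logits, y_batch):
    # Fraction of samples whose highest-scoring class equals the integer label
    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))

def cross_entropy_np(logits, labels):
    # Hypothetical stand-in for a loss like tl.cost.cross_entropy:
    # mean softmax cross-entropy over integer class labels.
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 1.5, 0.3],
                   [0.9, 0.1, 1.2],
                   [0.3, 0.2, 2.5]])
labels = np.array([0, 1, 0, 2])

print(acc(logits, labels))  # 0.75: rows 0, 1 and 3 are predicted correctly
print(cross_entropy_np(logits, labels))
```

Any callable with this (predictions, targets) shape can be passed as acc or cost.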


```python
def test(network, acc, X_test, y_test, batch_size, cost=None):
    """
    Test a given non-time-series network on the given test data and metric.

    Parameters
    ----------
    network : TensorLayer Model
        The network.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others.
        - If None, the metric is not printed.
    X_test : numpy.array
        The inputs of the test data.
    y_test : numpy.array
        The targets of the test data.
    batch_size : int or None
        The batch size for testing. When the dataset is large, test in minibatches;
        when it is small, set this to None to evaluate it in a single pass.
    cost : TensorLayer or TensorFlow loss function or None
        Loss function, e.g. tl.cost.cross_entropy. If None, the loss is not printed.

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_

    >>> def acc(_logits, y_batch):
    ...     return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
    >>> tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=tl.cost.cross_entropy)

    """
    tl.logging.info('Start testing the network ...')
    network.eval()
    if batch_size is None:
        y_pred = network(X_test)
        if cost is not None:
            test_loss = cost(y_pred, y_test)
            # tl.logging.info(" test loss: %f" % test_loss)
        test_acc = acc(y_pred, y_test)
        # tl.logging.info(" test acc: %f" % test_acc)
        return test_acc
    else:
        test_loss, test_acc, n_batch = run_epoch(
            network, X_test, y_test, cost=cost, acc=acc, batch_size=batch_size, shuffle=False
        )
        if cost is not None:
            tl.logging.info(" test loss: %f" % test_loss)
        tl.logging.info(" test acc: %f" % test_acc)
        return test_acc
```
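The two branches of test differ only in how the test set is fed to the network: batch_size=None does one forward pass over everything, while an integer batch size delegates to run_epoch with shuffle=False. The sketch below mimics that split with a hypothetical evaluate helper and a toy linear "network"; it assumes the batched path averages the per-batch metric over the number of batches, which is an assumption about run_epoch rather than its documented behavior.

```python
import numpy as np

def acc(_logits, y_batch):
    return np.mean(np.equal(np.argmax(_logits, 1), y_batch))

def evaluate(model_fn, metric, X, y, batch_size=None):
    # Hypothetical helper. batch_size=None: one forward pass over the whole
    # test set, mirroring the first branch of test().
    if batch_size is None:
        return float(metric(model_fn(X), y))
    # Otherwise run unshuffled minibatches and average the per-batch metric
    # (assumed aggregation; run_epoch may differ in detail).
    total, n_batch = 0.0, 0
    for start in range(0, len(X), batch_size):
        xb, yb = X[start:start + batch_size], y[start:start + batch_size]
        total += metric(model_fn(xb), yb)
        n_batch += 1
    return total / n_batch

rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))  # weights of a hypothetical linear "network"

def model_fn(x):
    return x @ W

X_test = rng.normal(size=(12, 5))
y_test = rng.integers(0, 3, size=12)

full = evaluate(model_fn, acc, X_test, y_test)
batched = evaluate(model_fn, acc, X_test, y_test, batch_size=4)
print(full, batched)
```

With 12 samples and a batch size of 4 every batch has the same size, so the mean of the batch accuracies equals the full-set accuracy. If the last batch were smaller, a plain mean of batch means would over-weight it.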

Callers

nothing calls this directly

Calls (3)

run_epoch · Function · 0.85
eval · Method · 0.80
acc · Function · 0.50
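One of the calls above is network.eval(), made before any prediction. In TensorLayer 2.x models this is understood to switch layers to inference behavior, so stochastic layers such as dropout stop perturbing the outputs (check the Model documentation for your version). TinyDropoutModel below is a hypothetical toy, not TensorLayer's Model; it only shows why evaluating in eval mode gives deterministic test outputs.

```python
import numpy as np

class TinyDropoutModel:
    """Hypothetical model with one inverted-dropout step.

    Mimics the train/eval switch that test() relies on when it calls
    network.eval(): dropout is applied only in training mode.
    """

    def __init__(self, keep=0.5, seed=0):
        self.keep = keep
        self.is_train = True
        self.rng = np.random.default_rng(seed)
        self.W = np.eye(3)

    def train(self):
        self.is_train = True

    def eval(self):
        self.is_train = False

    def __call__(self, x):
        h = x @ self.W
        if self.is_train:
            # Inverted dropout: zero units at random, rescale the survivors
            mask = self.rng.random(h.shape) < self.keep
            h = h * mask / self.keep
        return h

x = np.ones((2, 3))
model = TinyDropoutModel()
model.eval()
out1, out2 = model(x), model(x)
print(np.array_equal(out1, out2))  # True: no randomness in eval mode
```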

Tested by

no test coverage detected
