Test a given non-time-series network on the given test data and metric. Parameters: network (TensorLayer Model) - the network; acc (TensorFlow/numpy expression or None) - metric for accuracy or others; if None, the information is not printed.
(network, acc, X_test, y_test, batch_size, cost=None)
| 168 | |
| 169 | |
def test(network, acc, X_test, y_test, batch_size, cost=None):
    """
    Test a given non time-series network by the given test data and metric.

    Parameters
    ----------
    network : TensorLayer Model
        The network. It is switched to evaluation mode via ``network.eval()``.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. In the full-batch path it is called as
        ``acc(y_pred, y_test)``.
            - If None, the metric is not evaluated here and not logged.
    X_test : numpy.array
        The input of testing data.
    y_test : numpy array
        The target of testing data
    batch_size : int or None
        The batch size for testing, when dataset is large, we should use minibatches for testing;
        if dataset is small, we can set it to None.
    cost : TensorLayer or TensorFlow loss function
        Metric for loss function, e.g tl.cost.cross_entropy. If None, would not print the information.

    Returns
    -------
    The metric value. With ``batch_size=None`` this is ``acc(y_pred, y_test)``,
    or None when ``acc`` is None. Otherwise it is the accuracy entry returned
    by ``run_epoch`` (presumably aggregated over the mini-batches; see
    ``run_epoch`` for the exact definition).

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_

    >>> def acc(_logits, y_batch):
    ...     return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
    >>> tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=tl.cost.cross_entropy)

    """
    tl.logging.info('Start testing the network ...')
    network.eval()

    if batch_size is None:
        # Full-batch evaluation: a single forward pass over the whole test set.
        y_pred = network(X_test)
        if cost is not None:
            test_loss = cost(y_pred, y_test)
            tl.logging.info(" test loss: %f" % test_loss)
        if acc is None:
            # `acc` is documented as optional; without it there is nothing to
            # compute or report, so avoid calling None.
            return None
        test_acc = acc(y_pred, y_test)
        tl.logging.info(" test acc: %f" % test_acc)
        return test_acc

    # Mini-batch evaluation is delegated to run_epoch; the returned batch
    # count is not needed here.
    test_loss, test_acc, _ = run_epoch(
        network, X_test, y_test, cost=cost, acc=acc, batch_size=batch_size, shuffle=False
    )
    if cost is not None:
        tl.logging.info(" test loss: %f" % test_loss)
    if acc is not None:
        # Only report the metric when one was supplied; otherwise the value
        # coming back from run_epoch is not a meaningful accuracy.
        tl.logging.info(" test acc: %f" % test_acc)
    return test_acc
| 218 | |
| 219 | |
| 220 | def predict(network, X, batch_size=None): |