Run a given non-time-series network with the given cost function, test data, batch size, etc. for one epoch.
(network, X, y, cost=None, acc=None, batch_size=100, shuffle=False)
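The per-epoch averaging that this function performs can be illustrated without TensorLayer. The following is only a minimal NumPy sketch, not the library's implementation: `minibatches`, `run_epoch_sketch`, `predict`, `mse`, and `accuracy` are hypothetical stand-ins introduced for illustration, and dropping a trailing partial batch is an assumption of the sketch.

```python
import numpy as np


def minibatches(X, y, batch_size, shuffle=False, seed=0):
    """Yield (X_batch, y_batch) pairs (hypothetical helper).

    Assumption: a trailing batch smaller than batch_size is dropped.
    """
    indices = np.arange(len(X))
    if shuffle:
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(X) - batch_size + 1, batch_size):
        idx = indices[start:start + batch_size]
        yield X[idx], y[idx]


def run_epoch_sketch(predict, X, y, cost=None, acc=None, batch_size=100, shuffle=False):
    """Average the per-batch loss and metric over one pass through the data."""
    loss_sum, acc_sum, n_step = 0.0, 0.0, 0
    for X_batch, y_batch in minibatches(X, y, batch_size, shuffle=shuffle):
        output = predict(X_batch)
        if cost is not None:
            loss_sum += cost(output, y_batch)
        if acc is not None:
            acc_sum += acc(output, y_batch)
        n_step += 1
    if n_step == 0:
        # Guard added in this sketch only: no full batch was available.
        raise ValueError("fewer samples than batch_size")
    loss_ep = loss_sum / n_step if cost is not None else None
    acc_ep = acc_sum / n_step if acc is not None else None
    return loss_ep, acc_ep, n_step


# Toy data: ten scalar inputs, label 1 when the input is >= 5.
X = np.arange(10, dtype=np.float64).reshape(10, 1)
y = (X[:, 0] >= 5).astype(np.int64)


def predict(x):
    # Deliberately imperfect "model": misclassifies the input 4.
    return (x[:, 0] >= 4).astype(np.int64)


def mse(output, target):
    return float(np.mean((output - target) ** 2))


def accuracy(output, target):
    return float(np.mean(output == target))


loss_ep, acc_ep, n_step = run_epoch_sketch(predict, X, y, cost=mse, acc=accuracy, batch_size=5)
loss_only = run_epoch_sketch(predict, X, y, cost=mse, batch_size=5)
```

Because the result is a mean of per-batch means, every batch carries equal weight. Under the sketch's assumption that partial batches are dropped, samples beyond the last full batch do not contribute to the reported averages.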


def run_epoch(network, X, y, cost=None, acc=None, batch_size=100, shuffle=False):
    """Run a given non-time-series network with the given cost function, test data, batch size, etc.
    for one epoch.

    Parameters
    ----------
    network : TensorLayer Model
        The network to run. It is switched to evaluation mode.
    X : numpy.array
        The input data.
    y : numpy.array
        The target data.
    cost : TensorLayer or TensorFlow loss function
        Metric for the loss, e.g. tl.cost.cross_entropy.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, no metric is computed.
    batch_size : int
        The batch size used when iterating over the data.
    shuffle : boolean
        Whether to shuffle the dataset before batching.

    Returns
    -------
    loss_ep : Tensor. Average loss of this epoch. None if 'cost' is not given.
    acc_ep : Tensor. Average accuracy (metric) of this epoch. None if 'acc' is not given.
    n_step : int. Number of iterations taken in this epoch.
    """
    network.eval()
    loss_ep = 0
    acc_ep = 0
    n_step = 0
    for X_batch, y_batch in tl.iterate.minibatches(X, y, batch_size, shuffle=shuffle):
        _loss, _acc = _run_step(network, X_batch, y_batch, cost=cost, acc=acc)
        # Accumulate only the metrics that were requested.
        if cost is not None:
            loss_ep += _loss
        if acc is not None:
            acc_ep += _acc
        n_step += 1

    # Average the per-batch values over the number of steps taken.
    loss_ep = loss_ep / n_step if cost is not None else None
    acc_ep = acc_ep / n_step if acc is not None else None

    return loss_ep, acc_ep, n_step


@tf.function