Function run_epoch

tensorlayer/utils.py:612–654 in github.com/tensorlayer/TensorLayer

Evaluate a non-time-series network on X and y for one epoch. The function switches the network to evaluation mode (network.eval()), iterates over mini-batches of batch_size samples (shuffled if shuffle=True), and returns the average per-batch loss (None unless cost is given), the average per-batch metric (None unless acc is given), and the number of steps taken.
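run_epoch walks X and y in mini-batches produced by tl.iterate.minibatches. The sketch below imitates that iteration in plain Python so the batch count is easy to reason about. It rests on an assumption worth checking against the TensorLayer source: with its default allow_dynamic_batch_size=False, minibatches drops a trailing batch smaller than batch_size. The name iterate_minibatches and the seed parameter are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def iterate_minibatches(
    X: Sequence[T],
    y: Sequence[U],
    batch_size: int,
    shuffle: bool = False,
    seed: Optional[int] = None,
) -> Iterator[Tuple[List[T], List[U]]]:
    """Yield aligned (X_batch, y_batch) pairs of exactly batch_size items.

    Hypothetical stand-in for tl.iterate.minibatches; assumes a trailing
    batch smaller than batch_size is dropped, not yielded.
    """
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    if shuffle:
        # One permutation shared by X and y keeps inputs and targets paired.
        random.Random(seed).shuffle(indices)
    # Stop before any incomplete trailing batch.
    for start in range(0, len(X) - batch_size + 1, batch_size):
        idx = indices[start:start + batch_size]
        yield [X[i] for i in idx], [y[i] for i in idx]


data = list(range(250))
print(sum(1 for _ in iterate_minibatches(data, data, batch_size=100)))  # 2
```

Under that assumption, 250 samples with the default batch_size=100 give n_step = 2 and leave the last 50 samples unscored, while fewer than 100 samples give no batches at all.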

(network, X, y, cost=None, acc=None, batch_size=100, shuffle=False)


def run_epoch(network, X, y, cost=None, acc=None, batch_size=100, shuffle=False):
    """Run a given non time-series network by the given cost function, test data, batch_size etc.
    for one epoch.

    Parameters
    ----------
    network : TensorLayer Model
        the network to be trained.
    X : numpy.array
        The input of training data
    y : numpy.array
        The target of training data
    cost : TensorLayer or TensorFlow loss function
        Metric for loss function, e.g tl.cost.cross_entropy.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, would not print the information.
    batch_size : int
        The batch size for training and evaluating.
    shuffle : boolean
        Indicating whether to shuffle the dataset in training.

    Returns
    -------
    loss_ep : Tensor. Average loss of this epoch. None if 'cost' is not given.
    acc_ep : Tensor. Average accuracy(metric) of this epoch. None if 'acc' is not given.
    n_step : int. Number of iterations taken in this epoch.
    """
    network.eval()
    loss_ep = 0
    acc_ep = 0
    n_step = 0
    for X_batch, y_batch in tl.iterate.minibatches(X, y, batch_size, shuffle=shuffle):
        _loss, _acc = _run_step(network, X_batch, y_batch, cost=cost, acc=acc)
        if cost is not None:
            loss_ep += _loss
        if acc is not None:
            acc_ep += _acc
        n_step += 1

    loss_ep = loss_ep / n_step if cost is not None else None
    acc_ep = acc_ep / n_step if acc is not None else None

    return loss_ep, acc_ep, n_step
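The accumulation in run_epoch does not depend on TensorFlow: it sums one value per batch and divides by the number of batches. If cost returns a per-batch mean and every batch holds exactly batch_size samples (assuming tl.iterate.minibatches drops an incomplete final batch by default), loss_ep equals the per-sample mean over the samples actually visited, and samples in a dropped trailing batch never contribute. The sketch below restates that loop with plain Python callables; run_epoch_sketch, the list-based model, mae, accuracy and the (prediction, target) argument order are hypothetical stand-ins for the TensorLayer model, _run_step and tl.iterate.minibatches. One choice differs on purpose: the listing divides by n_step unconditionally, so with cost or acc given and no full batch (for example, fewer samples than batch_size) it would evaluate 0 / 0 and raise ZeroDivisionError; the sketch raises a ValueError with an explicit message instead.

```python
from typing import Callable, Optional, Sequence, Tuple

# (prediction, target) -> scalar; argument order is an assumption.
Metric = Callable[[list, list], float]


def run_epoch_sketch(
    model: Callable[[list], list],
    X: Sequence,
    y: Sequence,
    cost: Optional[Metric] = None,
    acc: Optional[Metric] = None,
    batch_size: int = 100,
) -> Tuple[Optional[float], Optional[float], int]:
    """Average cost/acc over full mini-batches, mirroring run_epoch's loop.

    Hypothetical, framework-free illustration; not part of TensorLayer.
    """
    loss_ep, acc_ep, n_step = 0.0, 0.0, 0
    # Full batches only (assumed drop-last default of tl.iterate.minibatches).
    for start in range(0, len(X) - batch_size + 1, batch_size):
        X_batch = list(X[start:start + batch_size])
        y_batch = list(y[start:start + batch_size])
        y_pred = model(X_batch)  # forward pass only; no parameter updates
        if cost is not None:
            loss_ep += cost(y_pred, y_batch)
        if acc is not None:
            acc_ep += acc(y_pred, y_batch)
        n_step += 1
    if n_step == 0 and (cost is not None or acc is not None):
        # The listed code would evaluate 0 / 0 here; fail with a clear message.
        raise ValueError("no full batch: len(X) must be at least batch_size")
    loss = loss_ep / n_step if cost is not None else None
    metric = acc_ep / n_step if acc is not None else None
    return loss, metric, n_step


def mae(pred, target):
    """Mean absolute error of one batch."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def accuracy(pred, target):
    """Fraction of exact matches in one batch."""
    return sum(p == t for p, t in zip(pred, target)) / len(pred)


# Identity "model"; the mismatched fifth sample falls in the dropped batch.
print(run_epoch_sketch(lambda b: b, [0, 1, 2, 3, 4], [0, 1, 2, 3, 5],
                       cost=mae, acc=accuracy, batch_size=2))  # (0.0, 1.0, 2)
```

The usage line shows the practical effect of dropping the last batch: the only wrong prediction sits in the fifth sample, which is never scored, so the reported accuracy is 1.0.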

Callers (2)

fit · Function · 0.85
test · Function · 0.85
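Both listed callers, fit and test, receive the (loss_ep, acc_ep, n_step) tuple, in which either average may be None. As a hypothetical illustration of consuming that contract, a caller could build a log line as below. The helper format_epoch and its output format are invented here and do not reproduce what fit or test actually print; the sketch also assumes the averages arrive as plain numbers.

```python
from typing import Optional


def format_epoch(
    loss_ep: Optional[float],
    acc_ep: Optional[float],
    n_step: int,
    prefix: str = "val",
) -> str:
    """Build a log line, omitting averages that were returned as None."""
    parts = [f"{prefix} steps: {n_step}"]
    if loss_ep is not None:
        parts.append(f"{prefix} loss: {float(loss_ep):.4f}")
    if acc_ep is not None:
        parts.append(f"{prefix} acc: {float(acc_ep):.4f}")
    return ", ".join(parts)


print(format_epoch(0.25, None, 3))  # val steps: 3, val loss: 0.2500
```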

Calls (2)

_run_step · Function · 0.85
eval · Method · 0.80
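run_epoch unpacks _loss, _acc from each _run_step call and reads a value only when the matching callable was passed. A minimal framework-free sketch of a helper with that contract follows. The name run_step_sketch, the list-based types and the (prediction, target) argument order are assumptions; the real _run_step presumably operates on TensorFlow tensors rather than Python lists.

```python
from typing import Callable, Optional, Sequence, Tuple

Metric = Callable[[Sequence, Sequence], float]


def run_step_sketch(
    model: Callable[[Sequence], Sequence],
    X_batch: Sequence,
    y_batch: Sequence,
    cost: Optional[Metric] = None,
    acc: Optional[Metric] = None,
) -> Tuple[Optional[float], Optional[float]]:
    """One forward pass; compute only the requested metrics.

    Hypothetical stand-in for TensorLayer's private _run_step.
    """
    y_pred = model(X_batch)  # single forward pass shared by both metrics
    loss = cost(y_pred, y_batch) if cost is not None else None
    metric = acc(y_pred, y_batch) if acc is not None else None
    return loss, metric


def sse(pred, target):
    """Sum of squared errors for one batch."""
    return sum((p - t) ** 2 for p, t in zip(pred, target))


# A model that doubles its inputs, scored by squared error only.
print(run_step_sketch(lambda b: [2 * v for v in b], [1, 2], [2, 5], cost=sse))  # (1, None)
```

Sharing one forward pass between cost and acc keeps the per-step cost to a single model call, which matters when both metrics are requested.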

Tested by

no test coverage detected
