hub / github.com/tensorlayer/TensorLayer / train_epoch

Function train_epoch

tensorlayer/utils.py:562–609

Trains a given non-time-series network for one epoch, using the given cost function, training data, batch size, and related settings.

(
    network, X, y, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None, batch_size=100, shuffle=True
)
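The default `train_op=tf.optimizers.Adam(learning_rate=0.0001)` in this signature is an ordinary Python default argument, so it is evaluated once, when the function is defined, and every call that omits `train_op` receives the same optimizer object. The sketch below illustrates that language behavior only. `FakeOptimizer`, `train_epoch_default`, and `train_epoch_fresh` are hypothetical stand-ins, not part of TensorFlow or TensorLayer.

```python
class FakeOptimizer:
    """Hypothetical stand-in for a stateful optimizer (not a TensorFlow class)."""

    def __init__(self, learning_rate=0.0001):
        self.learning_rate = learning_rate
        self.steps_taken = 0  # mimics internal state such as an iteration counter


def train_epoch_default(train_op=FakeOptimizer()):
    # The default above was built once, at function definition time.
    train_op.steps_taken += 1
    return train_op


def train_epoch_fresh(train_op=None):
    # Common alternative: build a new optimizer per call when none is passed.
    if train_op is None:
        train_op = FakeOptimizer()
    train_op.steps_taken += 1
    return train_op


first = train_epoch_default()
second = train_epoch_default()
shared = first is second            # both calls received the same object
shared_steps = second.steps_taken   # state carried over between the two calls

fresh_a = train_epoch_fresh()
fresh_b = train_epoch_fresh()
independent = fresh_a is not fresh_b  # a new object was built for each call
```

For a run that calls the function once per epoch, a shared optimizer may well be the desired behavior, since optimizer state then persists across epochs. Passing `train_op` explicitly makes that choice visible at the call site.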

Source from the content-addressed store, hash-verified

def train_epoch(
    network, X, y, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None, batch_size=100, shuffle=True
):
    """Training a given non time-series network by the given cost function, training data, batch_size etc.
    for one epoch.

    Parameters
    ----------
    network : TensorLayer Model
        the network to be trained.
    X : numpy.array
        The input of training data
    y : numpy.array
        The target of training data
    cost : TensorLayer or TensorFlow loss function
        Metric for loss function, e.g tl.cost.cross_entropy.
    train_op : TensorFlow optimizer
        The optimizer for training e.g. tf.optimizers.Adam().
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, would not print the information.
    batch_size : int
        The batch size for training and evaluating.
    shuffle : boolean
        Indicating whether to shuffle the dataset in training.

    Returns
    -------
    loss_ep : Tensor. Average loss of this epoch.
    acc_ep : Tensor or None. Average accuracy(metric) of this epoch. None if acc is not given.
    n_step : int. Number of iterations taken in this epoch.

    """
    network.train()
    loss_ep = 0
    acc_ep = 0
    n_step = 0
    for X_batch, y_batch in tl.iterate.minibatches(X, y, batch_size, shuffle=shuffle):
        _loss, _acc = _train_step(network, X_batch, y_batch, cost=cost, train_op=train_op, acc=acc)

        loss_ep += _loss
        if acc is not None:
            acc_ep += _acc
        n_step += 1

    loss_ep = loss_ep / n_step
    acc_ep = acc_ep / n_step if acc is not None else None

    return loss_ep, acc_ep, n_step
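The accumulate-then-average pattern used by `train_epoch` can be sketched without TensorFlow. The version below is a minimal illustration under stated assumptions, not the library's implementation: `iter_minibatches`, `run_one_epoch`, and `mean_squared_error` are hypothetical names, the iterator drops a trailing partial batch by choice of this sketch, and the step function only computes a loss rather than updating a model. It also adds an explicit check for the case where no batch is produced, since dividing by `n_step` when it is still 0 would raise `ZeroDivisionError`.

```python
import random


def iter_minibatches(X, y, batch_size, shuffle=False, seed=None):
    """Hypothetical stand-in for a minibatch iterator; yields full batches only."""
    if len(X) != len(y):
        raise ValueError("X and y must have the same length")
    indices = list(range(len(X)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    # Stop before a trailing partial batch (an assumption of this sketch).
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        batch = indices[start:start + batch_size]
        yield [X[i] for i in batch], [y[i] for i in batch]


def run_one_epoch(step_fn, X, y, batch_size=100, shuffle=True, seed=None):
    """Average the per-batch loss returned by step_fn over one pass of the data."""
    loss_sum = 0.0
    n_step = 0
    for X_batch, y_batch in iter_minibatches(X, y, batch_size, shuffle, seed):
        loss_sum += step_fn(X_batch, y_batch)
        n_step += 1
    if n_step == 0:
        # Fewer samples than batch_size: no step ran, so there is no mean to report.
        raise ValueError("batch_size is larger than the dataset")
    return loss_sum / n_step, n_step


def mean_squared_error(X_batch, y_batch):
    # Toy step function: the "model" predicts x itself; no parameters are updated.
    return sum((x - t) ** 2 for x, t in zip(X_batch, y_batch)) / len(X_batch)


X = [0.0, 1.0, 2.0, 3.0, 4.0]
y = [1.0, 1.0, 2.0, 5.0, 4.0]
loss_ep, n_step = run_one_epoch(mean_squared_error, X, y, batch_size=2, shuffle=False)
```

With five samples and `batch_size=2`, this sketch runs two steps and leaves the fifth sample out of the epoch. The reported loss is the mean of per-batch means, which equals the per-sample mean only when all batches have the same size.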

Callers 1

fit (Function) · 0.85

Calls 2

_train_step (Function) · 0.70
train (Method) · 0.45

Tested by

no test coverage detected
