
Function fit

tensorlayer/utils.py:26–167

Trains a given non-time-series network with the given cost function, training data, batch size, number of epochs, etc.

(
    network, train_op, cost, X_train, y_train, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None,
    y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True,
    tensorboard_graph_vis=True
)
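The parameter descriptions for print_freq, eval_train, X_val and y_val are easy to misread. The following is a minimal pure-Python sketch of the reporting rules as the docstring describes them, not TensorLayer code. The helper name reporting_plan is hypothetical, the rule that epoch 1 is always reported in addition to every multiple of print_freq is an assumption, and accuracy reporting (the acc argument) is left out.

```python
from typing import List, Tuple


def reporting_plan(n_epoch: int, print_freq: int, eval_train: bool,
                   has_val: bool) -> List[Tuple[int, List[str]]]:
    """Return (epoch, items) pairs for the epochs on which a report is printed.

    Hypothetical helper: epochs are numbered from 1, and the first epoch is
    assumed to be reported in addition to every multiple of print_freq.
    """
    plan = []
    for epoch in range(1, n_epoch + 1):
        if epoch != 1 and epoch % print_freq != 0:
            continue  # nothing is reported on this epoch
        if has_val:
            # With validation data, eval_train decides whether a separate
            # evaluation pass over the training data is also reported.
            items = (["train loss"] if eval_train else []) + ["val loss"]
        else:
            # Without validation data, only the loss of the training pass is reported.
            items = ["epoch loss"]
        plan.append((epoch, items))
    return plan


for epoch, items in reporting_plan(n_epoch=10, print_freq=5, eval_train=False, has_val=True):
    print(epoch, items)
```

Under these assumptions, the call above reports epochs 1, 5 and 10, each with the validation loss only.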


def fit(
    network, train_op, cost, X_train, y_train, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None,
    y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True,
    tensorboard_graph_vis=True
):
    """Train a given non-time-series network with the given cost function, training data, batch_size, n_epoch, etc.

    - For an MNIST example, click `here <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_.
    - For finer control over the training details, the authors HIGHLY recommend ``tl.iterate``; see the two MNIST examples `1 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout1.py>`_, `2 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout1.py>`_.

    Parameters
    ----------
    network : TensorLayer Model
        The network to be trained.
    train_op : TensorFlow optimizer
        The optimizer for training, e.g. tf.optimizers.Adam().
    cost : TensorLayer or TensorFlow loss function
        The loss function, e.g. tl.cost.cross_entropy.
    X_train : numpy.array
        The input of the training data.
    y_train : numpy.array
        The target of the training data.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or other measures. If None, this information is not printed.
    batch_size : int
        The batch size for training and evaluating.
    n_epoch : int
        The number of training epochs.
    print_freq : int
        Print the training information every ``print_freq`` epochs.
    X_val : numpy.array or None
        The input of the validation data. If None, no validation is performed.
    y_val : numpy.array or None
        The target of the validation data. If None, no validation is performed.
    eval_train : boolean
        Whether to evaluate the model during training.
        If X_val and y_val are not None, it controls whether the model is also evaluated on the training data.
    tensorboard_dir : string
        Path to the log directory. If set, summary data is stored in the tensorboard_dir/ directory for visualization with TensorBoard (default None).
    tensorboard_epoch_freq : int
        Number of epochs between writes of TensorBoard summaries to the tensorboard_dir/ directory (default 5).
    tensorboard_weight_histograms : boolean
        If True, weight histograms are written to the tensorboard_dir/ directory every ``tensorboard_epoch_freq`` epochs (default True).
    tensorboard_graph_vis : boolean
        If True, the graph is stored in the TensorBoard summaries saved to the tensorboard_dir/ directory (default True).

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_.

    >>> tl.utils.fit(network, train_op=tf.optimizers.Adam(learning_rate=0.0001),
    ...              cost=tl.cost.cross_entropy, X_train=X_train, y_train=y_train, acc=acc,
    ...              batch_size=64, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=True)
    >>> tl.utils.fit(network, train_op, cost, X_train, y_train,
    ...              acc=acc, batch_size=500, n_epoch=200, print_freq=5,
    ...              X_val=X_val, y_val=y_val, eval_train=False, tensorboard_dir='logs')

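The docstring recommends tl.iterate when finer control over training is needed, which means writing the epoch and batch loop yourself. As a minimal sketch of what such a loop iterates over, the hypothetical minibatches generator below splits paired inputs and targets into fixed-size batches. It only mimics the idea behind tl.iterate.minibatches and is not its implementation; the seed argument and dropping the trailing partial batch are assumptions of this sketch.

```python
import random
from typing import Iterator, List, Sequence, Tuple


def minibatches(inputs: Sequence, targets: Sequence, batch_size: int,
                shuffle: bool = False, seed: int = 0) -> Iterator[Tuple[List, List]]:
    """Yield (inputs, targets) batches of exactly batch_size items.

    Hypothetical stand-in in the spirit of tl.iterate.minibatches; any trailing
    partial batch is dropped (an assumption of this sketch).
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    indices = list(range(len(inputs)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # reproducible shuffle per seed
    for start in range(0, len(indices) - batch_size + 1, batch_size):
        batch = indices[start:start + batch_size]
        yield [inputs[i] for i in batch], [targets[i] for i in batch]


X = list(range(10))
y = [2 * x for x in X]
for epoch in range(2):
    for xb, yb in minibatches(X, y, batch_size=4, shuffle=True, seed=epoch):
        pass  # a hand-written training step on (xb, yb) would go here
```

With 10 samples and batch_size=4, each epoch yields two batches and the remaining two samples are skipped; reshuffling with a different seed per epoch changes which samples those are.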

Callers

nothing calls this directly

Calls (2)

train_epoch (function)
run_epoch (function)
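The call list above shows fit delegating to train_epoch and run_epoch. The following is a minimal pure-Python sketch of such a split, assuming (not taken from the TensorLayer source) that the training pass updates the weights batch by batch and returns the mean batch loss, while the evaluation pass only measures the loss. LinearModel, batches and the learning-rate argument lr are hypothetical; the real fit receives a TensorFlow optimizer as train_op instead.

```python
from typing import Iterator, List, Tuple


class LinearModel:
    """Toy 1-D model y = w * x + b; a hypothetical stand-in for a TensorLayer network."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def predict(self, x: float) -> float:
        return self.w * x + self.b


def batches(X: List[float], y: List[float], batch_size: int) -> Iterator[Tuple[list, list]]:
    for start in range(0, len(X), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


def train_epoch(model: LinearModel, X, y, batch_size: int, lr: float) -> float:
    """One pass that updates the weights; returns the mean mini-batch MSE."""
    losses = []
    for xb, yb in batches(X, y, batch_size):
        errs = [model.predict(x) - t for x, t in zip(xb, yb)]
        losses.append(sum(e * e for e in errs) / len(errs))
        # Gradients of the batch mean squared error with respect to w and b.
        grad_w = 2 * sum(e * x for e, x in zip(errs, xb)) / len(errs)
        grad_b = 2 * sum(errs) / len(errs)
        model.w -= lr * grad_w
        model.b -= lr * grad_b
    return sum(losses) / len(losses)


def run_epoch(model: LinearModel, X, y, batch_size: int) -> float:
    """One pass without weight updates; returns the mean mini-batch MSE."""
    losses = []
    for xb, yb in batches(X, y, batch_size):
        errs = [model.predict(x) - t for x, t in zip(xb, yb)]
        losses.append(sum(e * e for e in errs) / len(errs))
    return sum(losses) / len(losses)


X_train = [i / 10 for i in range(20)]
y_train = [3 * x + 1 for x in X_train]  # noise-free line y = 3x + 1
model = LinearModel()
for epoch in range(1, 201):
    epoch_loss = train_epoch(model, X_train, y_train, batch_size=5, lr=0.1)
    if epoch == 1 or epoch % 50 == 0:
        eval_loss = run_epoch(model, X_train, y_train, batch_size=5)
        print(f"epoch {epoch}: train loss {epoch_loss:.6f}, eval loss {eval_loss:.6f}")
```

Here the evaluation pass reuses the training data only to keep the sketch short; in fit the same role would be played by X_val and y_val when they are given.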

Tested by

no test coverage detected