Training a given non time-series network by the given cost function, training data, batch_size, n_epoch etc. - MNIST example click `here `_. - In order to control the training details, the authors HIGHL
(
network, train_op, cost, X_train, y_train, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None,
y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True,
tensorboard_graph_vis=True
)
| 24 | |
| 25 | |
| 26 | def fit( |
| 27 | network, train_op, cost, X_train, y_train, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None, |
| 28 | y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, |
| 29 | tensorboard_graph_vis=True |
| 30 | ): |
| 31 | """Train a given non-time-series network with the given cost function, training data, batch_size, n_epoch, etc. |
| 32 | |
| 33 | - For an MNIST example, see `here <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_. |
| 34 | - To control the training details, the authors HIGHLY recommend ``tl.iterate``; see two MNIST examples: `1 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout1.py>`_, `2 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout1.py>`_. |
| 35 | |
| 36 | Parameters |
| 37 | ---------- |
| 38 | network : TensorLayer Model |
| 39 | The network to be trained. |
| 40 | train_op : TensorFlow optimizer |
| 41 | The optimizer for training, e.g. tf.optimizers.Adam(). |
| 42 | cost : TensorLayer or TensorFlow loss function |
| 43 | The loss function to minimize, e.g. tl.cost.cross_entropy. |
| 44 | X_train : numpy.array |
| 45 | The input of training data. |
| 46 | y_train : numpy.array |
| 47 | The target of training data. |
| 48 | acc : TensorFlow/numpy expression or None |
| 49 | Metric for accuracy or others. If None, the metric is not printed. |
| 50 | batch_size : int |
| 51 | The batch size for training and evaluating. |
| 52 | n_epoch : int |
| 53 | The number of training epochs. |
| 54 | print_freq : int |
| 55 | Print the training information every ``print_freq`` epochs. |
| 56 | X_val : numpy.array or None |
| 57 | The input of validation data. If None, validation is not performed. |
| 58 | y_val : numpy.array or None |
| 59 | The target of validation data. If None, validation is not performed. |
| 60 | eval_train : boolean |
| 61 | Whether to evaluate the model during training. |
| 62 | If X_val and y_val are not None, it controls whether the model is also evaluated on the training data. |
| 63 | tensorboard_dir : string |
| 64 | Path to the log directory. If set, summary data is stored in the tensorboard_dir/ directory for visualization with TensorBoard (default None). |
| 65 | tensorboard_epoch_freq : int |
| 66 | Number of epochs between storing TensorBoard checkpoints for visualization in the tensorboard_dir/ directory (default 5). |
| 67 | tensorboard_weight_histograms : boolean |
| 68 | If True, updates the TensorBoard data in the tensorboard_dir/ directory for visualization |
| 69 | of the weight histograms every tensorboard_epoch_freq epochs (default True). |
| 70 | tensorboard_graph_vis : boolean |
| 71 | If True, stores the graph in the TensorBoard summaries saved to the tensorboard_dir/ directory (default True). |
| 72 | |
| 73 | Examples |
| 74 | -------- |
| 75 | See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_ |
| 76 | |
| 77 | >>> tl.utils.fit(network, train_op=tf.optimizers.Adam(learning_rate=0.0001), |
| 78 | ... cost=tl.cost.cross_entropy, X_train=X_train, y_train=y_train, acc=acc, |
| 79 | ... batch_size=64, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=True) |
| 80 | >>> tl.utils.fit(network, train_op, cost, X_train, y_train, |
| 81 | ... acc=acc, batch_size=500, n_epoch=200, print_freq=5, |
| 82 | ... X_val=X_val, y_val=y_val, eval_train=False, tensorboard_dir='logs/') |
| 83 |
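The docstring recommends iterating over minibatches by hand when finer control over training is needed. The following is a minimal sketch of that pattern in plain NumPy, not the library's implementation: `minibatches`, `fit_sketch`, and `train_step` are hypothetical names introduced here, and a one-parameter least-squares model stands in for a real network and optimizer. It only illustrates how `batch_size`, `n_epoch`, and `print_freq` interact.

```python
import numpy as np


def minibatches(inputs, targets, batch_size, shuffle=False, rng=None):
    """Yield (inputs, targets) batches of exactly batch_size samples.

    Hypothetical helper for this sketch; a trailing partial batch is dropped.
    """
    assert len(inputs) == len(targets)
    indices = np.arange(len(inputs))
    if shuffle:
        if rng is None:
            rng = np.random.default_rng()
        rng.shuffle(indices)
    for start in range(0, len(inputs) - batch_size + 1, batch_size):
        batch = indices[start:start + batch_size]
        yield inputs[batch], targets[batch]


def fit_sketch(train_step, X_train, y_train, batch_size=100, n_epoch=100, print_freq=5):
    """Call train_step on every minibatch; log the mean loss on the first
    epoch and then every print_freq epochs."""
    history = []
    for epoch in range(n_epoch):
        losses = [
            train_step(X_b, y_b)
            for X_b, y_b in minibatches(X_train, y_train, batch_size, shuffle=True)
        ]
        if epoch == 0 or (epoch + 1) % print_freq == 0:
            mean_loss = float(np.mean(losses))
            history.append((epoch + 1, mean_loss))
            print("Epoch %d of %d, train loss: %f" % (epoch + 1, n_epoch, mean_loss))
    return history


# Toy stand-in for a network: fit y = w * x with plain gradient descent.
data_rng = np.random.default_rng(0)
X = data_rng.normal(size=(250, 1))
y = 3.0 * X[:, 0]
params = {"w": np.zeros(1)}


def train_step(X_b, y_b, lr=0.1):
    """One gradient step on the mean squared error; returns the pre-update loss."""
    err = X_b @ params["w"] - y_b
    params["w"] -= lr * 2.0 * (X_b.T @ err) / len(y_b)
    return float(np.mean(err ** 2))


history = fit_sketch(train_step, X, y, batch_size=50, n_epoch=10, print_freq=5)
```

Writing the loop by hand like this makes it straightforward to add per-batch logic (learning-rate schedules, custom logging, early stopping) that a single `fit` call does not expose.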