Train a model for `training_iterations` iterations with batch size `batch_size`, printing accuracy every `evaluation_interval`. :param dataset_url: The MNIST dataset URL. :param training_iterations: The number of training iterations to run. :param batch_size: The batch size for training. :param evaluation_interval: The interval used to print the accuracy.
(dataset_url, training_iterations, batch_size, evaluation_interval)
| 30 | |
| 31 | |
| 32 | def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval): |
| 33 | """ |
| 34 | Train a model for training iterations with a batch size batch_size, printing accuracy every log_interval. |
| 35 | :param dataset_url: The MNIST dataset url. |
| 36 | :param training_iterations: The training iterations to train for. |
| 37 | :param batch_size: The batch size for training. |
| 38 | :param evaluation_interval: The interval used to print the accuracy. |
| 39 | :return: |
| 40 | """ |
| 41 | with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader: |
| 42 | with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader: |
| 43 | train_readout = tf_tensors(train_reader) |
| 44 | train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32) |
| 45 | train_label = train_readout.digit |
| 46 | batch_image, batch_label = tf.train.batch( |
| 47 | [train_image, train_label], batch_size=batch_size |
| 48 | ) |
| 49 | |
| 50 | W = tf.Variable(tf.zeros([784, 10])) |
| 51 | b = tf.Variable(tf.zeros([10])) |
| 52 | y = tf.matmul(batch_image, W) + b |
| 53 | |
| 54 | # The raw formulation of cross-entropy, |
| 55 | # |
| 56 | # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)), |
| 57 | # reduction_indices=[1])) |
| 58 | # |
| 59 | # can be numerically unstable. |
| 60 | # |
| 61 | # So here we use tf.losses.sparse_softmax_cross_entropy on the raw |
| 62 | # outputs of 'y', and then average across the batch. |
| 63 | cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=batch_label, logits=y) |
| 64 | train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) |
| 65 | |
| 66 | correct_prediction = tf.equal(tf.argmax(y, 1), batch_label) |
| 67 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) |
| 68 | |
| 69 | test_readout = tf_tensors(test_reader) |
| 70 | test_image = tf.cast(tf.reshape(test_readout.image, [784]), tf.float32) |
| 71 | test_label = test_readout.digit |
| 72 | test_batch_image, test_batch_label = tf.train.batch( |
| 73 | [test_image, test_label], batch_size=batch_size |
| 74 | ) |
| 75 | |
| 76 | # Train |
| 77 | print('Training model for {0} training iterations with batch size {1} and evaluation interval {2}'.format( |
| 78 | training_iterations, batch_size, evaluation_interval |
| 79 | )) |
| 80 | with tf.Session() as sess: |
| 81 | sess.run([ |
| 82 | tf.local_variables_initializer(), |
| 83 | tf.global_variables_initializer(), |
| 84 | ]) |
| 85 | coord = tf.train.Coordinator() |
| 86 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
| 87 | try: |
| 88 | for i in range(training_iterations): |
| 89 | if coord.should_stop(): |
no test coverage detected
searching dependent graphs…