Repository: github.com/uber/petastorm

Function train_and_test

examples/mnist/tf_example.py:32–103

Train a model for training_iterations iterations with batch size batch_size, printing the accuracy every evaluation_interval iterations.
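The visible part of the listing stops before the evaluation logic, so the exact condition used to decide when to print is not shown. A minimal sketch, assuming accuracy is reported on every 0-based iteration `i` for which `i % evaluation_interval == 0`; `evaluation_steps` is a hypothetical helper, not part of petastorm:

```python
def evaluation_steps(training_iterations, evaluation_interval):
    """Hypothetical helper: 0-based iterations at which accuracy would be printed.

    Assumes the check is `i % evaluation_interval == 0` inside
    `for i in range(training_iterations)`.
    """
    if evaluation_interval <= 0:
        raise ValueError("evaluation_interval must be a positive integer")
    return [i for i in range(training_iterations) if i % evaluation_interval == 0]


print(evaluation_steps(10, 4))  # [0, 4, 8]
```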

(dataset_url, training_iterations, batch_size, evaluation_interval)


```python
# Module-level imports assumed from the part of tf_example.py not shown (lines 1-31);
# the code below uses the TensorFlow 1.x graph/session API.
import os

import tensorflow as tf

from petastorm import make_reader
from petastorm.tf_utils import tf_tensors


def train_and_test(dataset_url, training_iterations, batch_size, evaluation_interval):
    """
    Train a model for training_iterations iterations with batch size batch_size, printing the accuracy
    every evaluation_interval iterations.
    :param dataset_url: The MNIST dataset url.
    :param training_iterations: The number of training iterations to run.
    :param batch_size: The batch size for training.
    :param evaluation_interval: The interval, in training iterations, at which the accuracy is printed.
    :return:
    """
    with make_reader(os.path.join(dataset_url, 'train'), num_epochs=None) as train_reader:
        with make_reader(os.path.join(dataset_url, 'test'), num_epochs=None) as test_reader:
            # Flatten each 28x28 image into a 784-element float vector and batch it with its label.
            train_readout = tf_tensors(train_reader)
            train_image = tf.cast(tf.reshape(train_readout.image, [784]), tf.float32)
            train_label = train_readout.digit
            batch_image, batch_label = tf.train.batch(
                [train_image, train_label], batch_size=batch_size
            )

            # Softmax regression: a weight matrix and bias vector mapping 784 pixels to 10 digit logits.
            W = tf.Variable(tf.zeros([784, 10]))
            b = tf.Variable(tf.zeros([10]))
            y = tf.matmul(batch_image, W) + b

            # The raw formulation of cross-entropy,
            #
            #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
            #                                 reduction_indices=[1]))
            #
            # can be numerically unstable.
            #
            # So here we use tf.losses.sparse_softmax_cross_entropy on the raw
            # outputs of 'y', and then average across the batch.
            cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=batch_label, logits=y)
            train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

            correct_prediction = tf.equal(tf.argmax(y, 1), batch_label)
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

            test_readout = tf_tensors(test_reader)
            test_image = tf.cast(tf.reshape(test_readout.image, [784]), tf.float32)
            test_label = test_readout.digit
            test_batch_image, test_batch_label = tf.train.batch(
                [test_image, test_label], batch_size=batch_size
            )

            # Train
            print('Training model for {0} training iterations with batch size {1} and evaluation interval {2}'.format(
                training_iterations, batch_size, evaluation_interval
            ))
            with tf.Session() as sess:
                sess.run([
                    tf.local_variables_initializer(),
                    tf.global_variables_initializer(),
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)
                try:
                    for i in range(training_iterations):
                        if coord.should_stop():
                            # (listing truncated here; the function continues through line 103 of the file)
```
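The model in the listing is multinomial logistic (softmax) regression: `y = xW + b` with 784 inputs and 10 classes, trained by plain gradient descent at learning rate 0.5. The pure-Python sketch below mirrors those pieces on toy data so the math can be checked without TensorFlow. The loss follows the approach the code comment describes (cross-entropy computed directly from raw logits via log-sum-exp, averaged over the batch), and accuracy compares the arg-max logit with the label. All function names are hypothetical; this is an illustration under those assumptions, not petastorm or TensorFlow code.

```python
import math


def compute_logits(x, W, b):
    # y = x W + b for one example: x has F features, W is F x C, b has C entries.
    return [sum(x[f] * W[f][c] for f in range(len(x))) + b[c] for c in range(len(b))]


def softmax(y):
    # Shift by max(y) before exponentiating so large logits cannot overflow.
    m = max(y)
    exps = [math.exp(v - m) for v in y]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_softmax_cross_entropy(labels, batch_logits):
    # Mean over the batch of logsumexp(y) - y[label], computed in log space.
    loss = 0.0
    for label, y in zip(labels, batch_logits):
        m = max(y)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in y))
        loss += log_sum_exp - y[label]
    return loss / len(labels)


def gradient_descent_step(images, labels, W, b, learning_rate=0.5):
    # The gradient of the mean loss w.r.t. the logits is (softmax(y) - one_hot(label)) / n;
    # the chain rule through y = x W + b gives the W and b gradients. Updates W and b in place.
    n = len(labels)
    grad_W = [[0.0] * len(b) for _ in W]
    grad_b = [0.0] * len(b)
    for x, label in zip(images, labels):
        delta = softmax(compute_logits(x, W, b))
        delta[label] -= 1.0
        for c, g in enumerate(delta):
            grad_b[c] += g / n
            for f, xf in enumerate(x):
                grad_W[f][c] += xf * g / n
    for f, row in enumerate(W):
        for c in range(len(b)):
            row[c] -= learning_rate * grad_W[f][c]
    for c in range(len(b)):
        b[c] -= learning_rate * grad_b[c]


def accuracy(images, labels, W, b):
    # Fraction of examples whose arg-max logit equals the label
    # (ties resolve to the lowest class index).
    correct = 0
    for x, label in zip(images, labels):
        y = compute_logits(x, W, b)
        predicted = max(range(len(y)), key=y.__getitem__)
        correct += predicted == label
    return correct / len(labels)


# Toy data: 2 features, 2 classes (the listing uses 784 features, 10 classes).
images = [[1.0, 0.0], [0.0, 1.0]]
labels = [0, 1]
W = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]

print(accuracy(images, labels, W, b))  # 0.5: all-zero logits tie, so class 0 is predicted
gradient_descent_step(images, labels, W, b, learning_rate=0.5)
print(accuracy(images, labels, W, b))  # 1.0
print(sparse_softmax_cross_entropy([0], [[1000.0, 0.0]]))  # 0.0, no overflow
```

Working in log space via log-sum-exp avoids both `exp` overflow on large logits and the `log(0)` that results when a softmax probability underflows to zero, which is the kind of instability the comment in the listing warns about.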

Callers (1)

main (function) · 0.85

Calls (4)

make_reader (function) · 0.90
tf_tensors (function) · 0.90
run (method) · 0.80
join (method) · 0.45
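Both readers are opened with `num_epochs=None`, which in petastorm means the dataset is replayed for an unbounded number of epochs. The run length is therefore set by `training_iterations` rather than by the dataset size, while `tf.train.batch` groups consecutive rows into batches of `batch_size`. A pure-Python sketch of that data flow, with hypothetical `infinite_rows` and `batches` helpers standing in for the reader and the queue-based batcher (it ignores the background threads that `tf.train.start_queue_runners` launches):

```python
import itertools


def infinite_rows(rows):
    # Hypothetical stand-in for a reader opened with num_epochs=None: replays the rows forever.
    return itertools.cycle(rows)


def batches(row_iter, batch_size):
    # Hypothetical batcher: group consecutive rows into lists of exactly batch_size rows,
    # dropping a trailing partial batch if the source ever runs out.
    while True:
        batch = list(itertools.islice(row_iter, batch_size))
        if len(batch) < batch_size:
            return
        yield batch


rows = [("img0", 0), ("img1", 1), ("img2", 2)]  # toy (image, digit) rows
training_iterations = 3
for step, batch in zip(range(training_iterations), batches(infinite_rows(rows), batch_size=2)):
    print(step, [digit for _, digit in batch])
# 0 [0, 1]
# 1 [2, 0]
# 2 [1, 2]
```

Because the row source never ends, the loop bound (`training_iterations`) is the only thing that stops consumption, which matches how the listing iterates with `for i in range(training_iterations)`.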

Tested by

no test coverage detected
