MCPcopy
hub / github.com/naturomics/CapsNet-Tensorflow / load_mnist

Function load_mnist

utils.py:9–36  ·  view source on GitHub ↗
(path, is_training)

Source from the content-addressed store, hash-verified

7
8
def _read_idx(path, filename, header_bytes):
    """Read one raw MNIST IDX file and return its payload as a flat uint8 array.

    Args:
        path: directory containing the MNIST files.
        filename: name of the IDX file inside ``path``.
        header_bytes: number of leading header bytes to skip (16 for the
            image files, 8 for the label files).

    Returns:
        A 1-D ``np.uint8`` array holding everything after the header.
    """
    # Binary mode: the payload is raw bytes. The ``with`` block closes the
    # handle instead of leaking it as the previous bare ``open`` calls did.
    with open(os.path.join(path, filename), 'rb') as fd:
        loaded = np.fromfile(file=fd, dtype=np.uint8)
    return loaded[header_bytes:]


def load_mnist(path, is_training):
    """Load either the MNIST training split or the test split from raw IDX files.

    Args:
        path: directory holding the uncompressed MNIST IDX files
            (``train-images-idx3-ubyte``, ``train-labels-idx1-ubyte``,
            ``t10k-images-idx3-ubyte``, ``t10k-labels-idx1-ubyte``).
            Previously this argument was ignored in favour of the global
            ``cfg.dataset``; it is now honoured.
        is_training: if true, return the training split; otherwise the test
            split. Only the files for the requested split are read.

    Returns:
        ``(images, labels)``:
          - images: shape ``[num_samples, 28, 28, 1]``, scaled to ``[0, 1]``.
            For the training split this is a ``tf.float32`` tensor; for the
            test split it is a ``float64`` NumPy array (the same return types
            the function has always had).
          - labels: one-hot ``tf.float32`` tensor of shape
            ``[num_samples, 10]``.
    """
    prefix = 'train' if is_training else 't10k'

    # -1 lets the sample count follow the file (60000 train / 10000 test for
    # standard MNIST) instead of being hard-coded per split.
    images = _read_idx(path, prefix + '-images-idx3-ubyte', 16)
    images = images.reshape((-1, 28, 28, 1)).astype(np.float64) / 255.

    # tf.one_hot only accepts integer indices, so labels must not be cast to
    # float here.
    labels = _read_idx(path, prefix + '-labels-idx1-ubyte', 8).astype(np.int32)
    # => [num_samples, 10]
    labels = tf.one_hot(labels, depth=10, axis=1, dtype=tf.float32)

    if is_training:
        # The training images are handed back as a tensor, as before.
        images = tf.convert_to_tensor(images, tf.float32)
    return images, labels
37
38
39def get_batch_data():

Callers 3

eval.py — File · 0.90
get_batch_data — Function · 0.85
utils.py — File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected