MCPcopy Index your code
hub / github.com/naturomics/CapsNet-Tensorflow / build_arch

Method build_arch

capsNet.py:29–86  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

27 tf.logging.info('Seting up the main structure')
28
def build_arch(self):
    """Build the CapsNet graph: Conv1 -> PrimaryCaps -> DigitCaps -> masking -> decoder.

    Reads ``self.X`` and the module-level ``cfg.batch_size``. ``self.X`` is
    assumed to be 28x28 spatially (a 9x9 VALID conv must yield the asserted
    20x20 map) -- TODO confirm against the input pipeline.

    Sets on ``self``:
        caps2:     DigitCaps output, expected [batch_size, 10, 16, 1].
        v_length:  L2 norm of each DigitCaps vector, [batch_size, 10, 1, 1].
        softmax_v: softmax of ``v_length`` over the 10 capsules (axis 1).
        masked_v:  per example, the DigitCaps vector whose softmax value is
                   largest, [batch_size, 1, 16, 1].
        decoded:   decoder output, [batch_size, 784], sigmoid-activated.
    """
    with tf.variable_scope('Conv1_layer'):
        # Conv1, [batch_size, 20, 20, 256]
        conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256,
                                         kernel_size=9, stride=1,
                                         padding='VALID')
        assert conv1.get_shape() == [cfg.batch_size, 20, 20, 256]

    # TODO: Rewrite the 'CapsConv' class as functions: the capsule layer
    # should be split into two functions, one like conv2d and another like
    # fully_connected in TensorFlow.
    # Primary Capsules, [batch_size, 1152, 8, 1]
    with tf.variable_scope('PrimaryCaps_layer'):
        primaryCaps = CapsConv(num_units=8, with_routing=False)
        caps1 = primaryCaps(conv1, num_outputs=32, kernel_size=9, stride=2)
        assert caps1.get_shape() == [cfg.batch_size, 1152, 8, 1]

    # DigitCaps layer, [batch_size, 10, 16, 1]
    with tf.variable_scope('DigitCaps_layer'):
        digitCaps = CapsConv(num_units=16, with_routing=True)
        self.caps2 = digitCaps(caps1, num_outputs=10)

    # Decoder structure in Fig. 2
    # 1. Masking: keep only the DigitCaps vector with the largest length.
    with tf.variable_scope('Masking'):
        # a) ||v_c|| for each capsule, then softmax over the 10 capsules.
        # [batch_size, 10, 16, 1] => [batch_size, 10, 1, 1]
        self.v_length = tf.sqrt(tf.reduce_sum(tf.square(self.caps2),
                                              axis=2, keep_dims=True))
        self.softmax_v = tf.nn.softmax(self.v_length, dim=1)
        assert self.softmax_v.get_shape() == [cfg.batch_size, 10, 1, 1]

        # b) Index of the capsule with the max softmax value.
        # [batch_size, 10, 1, 1] => [batch_size, 1, 1] => [batch_size]
        argmax_idx = tf.argmax(self.softmax_v, axis=1, output_type=tf.int32)
        assert argmax_idx.get_shape() == [cfg.batch_size, 1, 1]
        argmax_idx = tf.reshape(argmax_idx, shape=(cfg.batch_size, ))

        # c) Select caps2[i, argmax_idx[i]] for every example i with a single
        # gather, instead of emitting one slice op per example in a Python
        # loop (which grows the graph linearly with batch size).
        # indices: [batch_size, 2], each row is (example index, capsule index).
        example_idx = tf.range(cfg.batch_size, dtype=tf.int32)
        indices = tf.stack([example_idx, argmax_idx], axis=1)
        # [batch_size, 10, 16, 1] => [batch_size, 16, 1] => [batch_size, 1, 16, 1]
        selected_v = tf.gather_nd(self.caps2, indices)
        self.masked_v = tf.reshape(selected_v,
                                   shape=(cfg.batch_size, 1, 16, 1))
        assert self.masked_v.get_shape() == [cfg.batch_size, 1, 16, 1]

    # 2. Reconstruct the MNIST images with 3 FC layers
    # [batch_size, 1, 16, 1] => [batch_size, 16] => [batch_size, 512]
    #   => [batch_size, 1024] => [batch_size, 784]
    with tf.variable_scope('Decoder'):
        vector_j = tf.reshape(self.masked_v, shape=(cfg.batch_size, -1))
        fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
        assert fc1.get_shape() == [cfg.batch_size, 512]
        fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
        assert fc2.get_shape() == [cfg.batch_size, 1024]
        self.decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)

Callers 1

__init__Method · 0.95

Calls 1

CapsConvClass · 0.90

Tested by

no test coverage detected