Compute logits of the model (output distribution). Args: is_training: (bool) whether we are training or not; inputs: (dict) contains the inputs of the graph (features, labels...), which can be `tf.placeholder` or outputs of `tf.data`; params: (Params) hyperparameters.
(is_training, inputs, params)
| 4 | |
| 5 | |
def _spatial_size_after_pooling(size, num_pools):
    """Return the height/width left after `num_pools` 2x2, stride-2 max-pools.

    `tf.layers.max_pooling2d(out, 2, 2)` uses its default 'valid' padding, so
    each pooling floors the spatial size (size -> size // 2); the 'same'-padded
    3x3 convolutions in between leave the size unchanged.
    """
    for _ in range(num_pools):
        size //= 2
    return size


def build_model(is_training, inputs, params):
    """Compute logits of the model (output distribution)

    The network is a stack of four conv blocks
    (3x3 conv -> optional batch norm -> relu -> 2x2 max-pool) whose widths are
    num_channels * [1, 2, 4, 8], followed by a hidden dense layer
    (num_channels * 8 units -> optional batch norm -> relu) and a final linear
    dense layer with `params.num_labels` units.

    Args:
        is_training: (bool) whether we are training or not; forwarded to batch
            normalization as its `training` flag
        inputs: (dict) contains the inputs of the graph (features, labels...)
            this can be `tf.placeholder` or outputs of `tf.data`; only
            `inputs['images']` is read here
        params: (Params) hyperparameters; reads `image_size`, `num_channels`,
            `bn_momentum`, `use_batch_norm` and `num_labels`

    Returns:
        output: (tf.Tensor) unscaled logits (no activation on the last dense
            layer) with `params.num_labels` columns

    Raises:
        ValueError: if `inputs['images']` does not have static shape
            [None, params.image_size, params.image_size, 3], or if
            `params.image_size` is too small to survive the pooling stack
    """
    images = inputs['images']

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    expected_shape = [None, params.image_size, params.image_size, 3]
    actual_shape = images.get_shape().as_list()
    if actual_shape != expected_shape:
        raise ValueError("Expected images of static shape {}, got {}".format(
            expected_shape, actual_shape))

    # Define the number of channels of each convolution
    # For each block, we do: 3x3 conv -> batch norm -> relu -> 2x2 maxpool
    num_channels = params.num_channels
    bn_momentum = params.bn_momentum
    channels = [num_channels, num_channels * 2, num_channels * 4, num_channels * 8]

    # Spatial size after the conv stack, derived from image_size rather than
    # hard-coded to 4 (which was only correct for 64x64 inputs).
    final_size = _spatial_size_after_pooling(params.image_size, len(channels))
    if final_size < 1:
        raise ValueError(
            "image_size {} is too small for {} 2x2 max-pooling layers".format(
                params.image_size, len(channels)))

    out = images
    for i, c in enumerate(channels):
        with tf.variable_scope('block_{}'.format(i + 1)):
            out = tf.layers.conv2d(out, c, 3, padding='same')
            if params.use_batch_norm:
                out = tf.layers.batch_normalization(out, momentum=bn_momentum, training=is_training)
            out = tf.nn.relu(out)
            out = tf.layers.max_pooling2d(out, 2, 2)

    # Internal invariant of the graph construction above (not input validation).
    assert out.get_shape().as_list() == [None, final_size, final_size, channels[-1]]

    # Flatten the feature maps for the fully connected head.
    out = tf.reshape(out, [-1, final_size * final_size * channels[-1]])
    with tf.variable_scope('fc_1'):
        out = tf.layers.dense(out, num_channels * 8)
        if params.use_batch_norm:
            out = tf.layers.batch_normalization(out, momentum=bn_momentum, training=is_training)
        out = tf.nn.relu(out)
    with tf.variable_scope('fc_2'):
        logits = tf.layers.dense(out, params.num_labels)

    return logits
| 48 | |
| 49 | |
| 50 | def model_fn(mode, inputs, params, reuse=False): |