MCPcopy
hub / github.com/cs230-stanford/cs230-code-examples / build_model

Function build_model

tensorflow/vision/model/model_fn.py:6–47  ·  view source on GitHub ↗

Compute logits of the model (output distribution). Args: is_training: (bool) whether we are training or not; inputs: (dict) contains the inputs of the graph (features, labels...), this can be `tf.placeholder` or outputs of `tf.data`; params: (Params) hyperparameters. Returns: output: (tf.Tensor) output of the model.

(is_training, inputs, params)

Source from the content-addressed store, hash-verified

4
5
def build_model(is_training, inputs, params):
    """Compute logits of the model (output distribution)

    The network is a stack of convolutional blocks, each doing
    3x3 conv -> (optional) batch norm -> relu -> 2x2 max pool, followed by a
    fully connected hidden layer and a final linear layer with one output per
    label.

    Args:
        is_training: (bool) whether we are training or not (controls the
            batch-norm mode when `params.use_batch_norm` is set)
        inputs: (dict) contains the inputs of the graph (features, labels...)
                this can be `tf.placeholder` or outputs of `tf.data`.
                Only `inputs['images']` is read here.
        params: (Params) hyperparameters; reads `image_size`, `num_channels`,
            `bn_momentum`, `use_batch_norm` and `num_labels`

    Returns:
        output: (tf.Tensor) logits of shape [batch_size, params.num_labels]
    """
    images = inputs['images']

    expected_input_shape = [None, params.image_size, params.image_size, 3]
    assert images.get_shape().as_list() == expected_input_shape, \
        "expected images of shape {}, got {}".format(
            expected_input_shape, images.get_shape().as_list())

    out = images
    # Define the number of channels of each convolution
    # For each block, we do: 3x3 conv -> batch norm -> relu -> 2x2 maxpool
    num_channels = params.num_channels
    bn_momentum = params.bn_momentum
    channels = [num_channels, num_channels * 2, num_channels * 4, num_channels * 8]
    for i, c in enumerate(channels):
        with tf.variable_scope('block_{}'.format(i+1)):
            out = tf.layers.conv2d(out, c, 3, padding='same')
            if params.use_batch_norm:
                # NOTE(review): in TF 1.x the moving-average updates of
                # batch_normalization are collected in tf.GraphKeys.UPDATE_OPS;
                # presumably the caller makes the train op depend on them — confirm.
                out = tf.layers.batch_normalization(out, momentum=bn_momentum, training=is_training)
            out = tf.nn.relu(out)
            out = tf.layers.max_pooling2d(out, 2, 2)

    # 'same' convs keep the spatial size; each 2x2/stride-2 pool (default
    # 'valid' padding) floors it by 2, so after len(channels) blocks the side
    # is image_size // 2**len(channels). Deriving it (instead of hard-coding 4)
    # keeps the flatten correct for any image_size, not only 64.
    spatial_size = params.image_size // 2 ** len(channels)
    assert spatial_size > 0, \
        "image_size {} is too small for {} pooling blocks".format(
            params.image_size, len(channels))

    expected_conv_shape = [None, spatial_size, spatial_size, channels[-1]]
    assert out.get_shape().as_list() == expected_conv_shape, \
        "expected conv features of shape {}, got {}".format(
            expected_conv_shape, out.get_shape().as_list())

    out = tf.reshape(out, [-1, spatial_size * spatial_size * channels[-1]])
    with tf.variable_scope('fc_1'):
        out = tf.layers.dense(out, num_channels * 8)
        if params.use_batch_norm:
            out = tf.layers.batch_normalization(out, momentum=bn_momentum, training=is_training)
        out = tf.nn.relu(out)
    with tf.variable_scope('fc_2'):
        logits = tf.layers.dense(out, params.num_labels)

    return logits
48
49
50def model_fn(mode, inputs, params, reuse=False):

Callers 1

model_fn — Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected