hub / github.com/cs230-stanford/cs230-code-examples / model_fn

Function model_fn

tensorflow/vision/model/model_fn.py:50–138

Model function defining the graph operations.

Args:
    mode: (string) can be 'train' or 'eval'
    inputs: (dict) contains the inputs of the graph (features, labels...); these can be `tf.placeholder` tensors or outputs of `tf.data`
    params: (Params) contains hyperparameters of the model

(mode, inputs, params, reuse=False)
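The argument layout described above can be sketched without TensorFlow. This is only an illustration under stated assumptions: `demo_params`, `demo_inputs`, and `describe_call` are hypothetical names, the `'images'` key and all numeric values are placeholders, and plain lists stand in for tensors. The only names taken from the source are `'labels'`, `learning_rate`, and `use_batch_norm`.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the repository's Params object. model_fn only
# needs attribute access such as params.learning_rate and params.use_batch_norm.
# The values are placeholders, not the project's settings.
demo_params = SimpleNamespace(learning_rate=1e-3, use_batch_norm=True)

# Hypothetical inputs dict. In the real graph these values would be tensors
# (tf.placeholder or tf.data outputs); lists are used here only to show the
# layout. 'images' is an assumed key; 'labels' is the key model_fn reads.
demo_inputs = {
    'images': [[0.0, 0.5], [1.0, 0.25]],
    'labels': [1, 0],
}


def describe_call(mode, inputs, params):
    """Mirror the first lines of model_fn: derive is_training, read the labels."""
    is_training = (mode == 'train')
    return {
        'is_training': is_training,
        'num_labels': len(inputs['labels']),
        'learning_rate': params.learning_rate,
    }


summary = describe_call('eval', demo_inputs, demo_params)
```

Because `is_training` is derived from `mode`, the training-only ops (optimizer, `train_op`) are built only when `mode == 'train'`.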


def model_fn(mode, inputs, params, reuse=False):
    """Model function defining the graph operations.

    Args:
        mode: (string) can be 'train' or 'eval'
        inputs: (dict) contains the inputs of the graph (features, labels...)
                this can be `tf.placeholder` or outputs of `tf.data`
        params: (Params) contains hyperparameters of the model (ex: `params.learning_rate`)
        reuse: (bool) whether to reuse the weights

    Returns:
        model_spec: (dict) contains the graph operations or nodes needed for training / evaluation
    """
    is_training = (mode == 'train')
    labels = inputs['labels']
    labels = tf.cast(labels, tf.int64)

    # -----------------------------------------------------------
    # MODEL: define the layers of the model
    with tf.variable_scope('model', reuse=reuse):
        # Compute the output distribution of the model and the predictions
        logits = build_model(is_training, inputs, params)
        predictions = tf.argmax(logits, 1)

    # Define loss and accuracy
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(labels, predictions), tf.float32))

    # Define training step that minimizes the loss with the Adam optimizer
    if is_training:
        optimizer = tf.train.AdamOptimizer(params.learning_rate)
        global_step = tf.train.get_or_create_global_step()
        if params.use_batch_norm:
            # Add a dependency to update the moving mean and variance for batch normalization
            with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
                train_op = optimizer.minimize(loss, global_step=global_step)
        else:
            train_op = optimizer.minimize(loss, global_step=global_step)


    # -----------------------------------------------------------
    # METRICS AND SUMMARIES
    # Metrics for evaluation using tf.metrics (average over whole dataset)
    with tf.variable_scope("metrics"):
        metrics = {
            'accuracy': tf.metrics.accuracy(labels=labels, predictions=tf.argmax(logits, 1)),
            'loss': tf.metrics.mean(loss)
        }

    # Group the update ops for the tf.metrics
    update_metrics_op = tf.group(*[op for _, op in metrics.values()])

    # Get the op to reset the local variables used in tf.metrics
    metric_variables = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="metrics")
    metrics_init_op = tf.variables_initializer(metric_variables)

    # Summaries for training
    tf.summary.scalar('loss', loss)
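
The loss, accuracy, and `tf.metrics` lines in the listing above can be mirrored in plain Python to show the arithmetic involved. This is a sketch, not the TensorFlow implementation: `sparse_softmax_cross_entropy`, `batch_accuracy`, and `StreamingMean` are hypothetical helpers written for this illustration, and the two toy batches are made up. It assumes that a streaming metric keeps a running total and count, that the update step adds to them, and that resetting them plays the role of `metrics_init_op`.

```python
import math


def sparse_softmax_cross_entropy(labels, logits):
    """Batch mean of -log(softmax(logits)[label]) for integer class labels."""
    losses = []
    for label, row in zip(labels, logits):
        m = max(row)  # subtract the row max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


def batch_accuracy(labels, logits):
    """Fraction of rows whose argmax matches the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


class StreamingMean:
    """Hypothetical running mean: a total and a count that persist across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Plays the role of metrics_init_op: zero the accumulators.
        self.total = 0.0
        self.count = 0.0

    def update(self, value, weight=1.0):
        # Plays the role of update_metrics_op: accumulate one batch.
        self.total += value * weight
        self.count += weight

    def result(self):
        return self.total / self.count if self.count else 0.0


# Two made-up batches of (labels, logits) with two classes.
batches = [
    ([0, 1], [[2.0, 0.0], [0.0, 2.0]]),  # both rows predicted correctly
    ([1], [[2.0, 0.0]]),                 # one row, predicted incorrectly
]

loss_metric = StreamingMean()
acc_metric = StreamingMean()
for labels, logits in batches:
    # The loss is a scalar per batch, so each batch contributes one value.
    loss_metric.update(sparse_softmax_cross_entropy(labels, logits))
    # Accuracy is accumulated per example, so weight by the batch size.
    acc_metric.update(batch_accuracy(labels, logits), weight=len(labels))

dataset_loss = loss_metric.result()
dataset_accuracy = acc_metric.result()
```

Under these assumptions the dataset-level accuracy is weighted by example count, while the loss is a mean of per-batch means, so the two can differ from a plain average of the per-batch `accuracy` and `loss` tensors when batch sizes vary. Calling `reset()` before each evaluation pass keeps results from one pass out of the next.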

Callers 2

train.py (File · 0.90)
evaluate.py (File · 0.90)

Calls 1

build_model (Function · 0.70)

Tested by

no test coverage detected