MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / get_cost

Function get_cost

tensorpack/contrib/keras.py:166–213  ·  view source on GitHub ↗
(*inputs)

Source from the content-addressed store, hash-verified

164 nr_inputs = len(input_signature)
165
def get_cost(*inputs):
    """Build the Keras model for the current tower and return its total cost.

    Relies on ``nr_inputs``, ``model_caller``, ``loss`` and ``metrics`` from
    the enclosing scope.

    Args:
        *inputs: the input tensors followed by the target tensors; the first
            ``nr_inputs`` are fed to ``model_caller`` and the remainder are
            used as targets.

    Returns:
        The sum of the per-output losses plus the collected regularization
        cost (when there is one), named ``TOTAL_LOSS_NAME``.

    Raises:
        AssertionError: if the number of model outputs differs from the
            number of targets or of losses, or if more metrics than outputs
            are given.
    """
    ctx = get_current_tower_context()
    input_tensors = list(inputs[:nr_inputs])
    target_tensors = list(inputs[nr_inputs:])

    outputs = model_caller(*input_tensors)
    # A single-output model yields a bare tensor; wrap it so it can be
    # paired element-wise with the target and loss lists below.
    if isinstance(outputs, tf.Tensor):
        outputs = [outputs]
    # TODO: targets, losses and metrics are matched to outputs purely by
    # position; there is no explicit mapping between them yet.
    assert len(outputs) == len(target_tensors), \
        "len({}) != len({})".format(str(outputs), str(target_tensors))
    assert len(outputs) == len(loss), \
        "len({}) != len({})".format(str(outputs), str(loss))

    loss_tensors = []
    # Lengths were asserted equal above, so zip pairs every element.
    for loss_name, output, target in zip(loss, outputs, target_tensors):
        with cached_name_scope('keras_loss', top_level=False):
            loss_fn = keras.losses.get(loss_name)
            curr_loss = loss_fn(target, output)
        # The reduced tensor is explicitly named after the loss; _check_name
        # guards against a name conflict having renamed it.
        curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
        _check_name(curr_loss, loss_name)
        loss_tensors.append(curr_loss)

    loss_reg = regularize_cost_from_collection()
    if loss_reg is not None:
        total_loss = tf.add_n(loss_tensors + [loss_reg], name=TOTAL_LOSS_NAME)
        add_moving_summary(loss_reg, total_loss, *loss_tensors)
    else:
        total_loss = tf.add_n(loss_tensors, name=TOTAL_LOSS_NAME)
        add_moving_summary(total_loss, *loss_tensors)

    # Metrics are built only on the main training tower and on non-training
    # towers; other training towers skip them.
    if metrics and (ctx.is_main_training_tower or not ctx.is_training):
        # One metric per output, matched by position. Fail with a clear
        # message instead of an IndexError when there are too many metrics.
        assert len(metrics) <= len(outputs), \
            "len({}) > len({})".format(str(metrics), str(outputs))
        metric_tensors = []
        for metric_name, output_tensor, target_tensor in zip(
                metrics, outputs, target_tensors):
            with cached_name_scope('keras_metric', top_level=False):
                metric_fn = keras.metrics.get(metric_name)
                metric_tensor = metric_fn(target_tensor, output_tensor)
            metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
            # check name conflict here
            _check_name(metric_tensor, metric_name)
            metric_tensors.append(metric_tensor)
        add_moving_summary(*metric_tensors)

    return total_loss
214
215 trainer.setup_graph(
216 input_signature + target_signature,

Callers

nothing calls this directly

Calls 8

cached_name_scope · Function · 0.85
_check_name · Function · 0.85
add_moving_summary · Function · 0.85
format · Method · 0.80
append · Method · 0.80
get · Method · 0.45

Tested by

no test coverage detected