(*inputs)
| 164 | nr_inputs = len(input_signature) |
| 165 | |
def get_cost(*inputs):
    """Build the tower's total cost from a Keras-style model.

    Closes over ``nr_inputs``, ``model_caller``, ``loss`` and ``metrics`` from
    the enclosing scope.

    Input handling:
        The first ``nr_inputs`` entries of ``inputs`` are fed to
        ``model_caller``. The remaining entries are treated as targets and are
        paired with the model outputs by position.

    Losses:
        For each output, the matching entry of ``loss`` is resolved with
        ``keras.losses.get`` and applied to (target, output). The result is
        averaged with ``tf.reduce_mean`` and named after the loss entry.

    Total cost:
        The cost returned by ``regularize_cost_from_collection()`` (when not
        None) is added to the per-output losses. The sum is named
        ``TOTAL_LOSS_NAME``. Moving summaries are added for the regularizer
        (if any), the total loss and every per-output loss.

    Metrics:
        When ``metrics`` is non-empty, and the tower is either the main
        training tower or not training, one metric is built per output (in
        order) and summarized.

    Returns:
        The total loss tensor.
    """
    ctx = get_current_tower_context()
    input_tensors = list(inputs[:nr_inputs])
    target_tensors = list(inputs[nr_inputs:])
    # TODO mapping between target tensors & output tensors

    outputs = model_caller(*input_tensors)

    # A bare tensor is wrapped in a list so it can be paired with
    # targets and losses by index below.
    if isinstance(outputs, tf.Tensor):
        outputs = [outputs]
    assert len(outputs) == len(target_tensors), \
        "len({}) != len({})".format(str(outputs), str(target_tensors))
    assert len(outputs) == len(loss), \
        "len({}) != len({})".format(str(outputs), str(loss))

    loss_tensors = []
    for idx, loss_name in enumerate(loss):
        with cached_name_scope('keras_loss', top_level=False):
            loss_fn = keras.losses.get(loss_name)
            curr_loss = loss_fn(target_tensors[idx], outputs[idx])
        # The reduced tensor is explicitly named `loss_name`; _check_name
        # presumably guards that this name was not altered by a conflict.
        curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
        _check_name(curr_loss, loss_name)
        loss_tensors.append(curr_loss)

    loss_reg = regularize_cost_from_collection()
    # When present, the regularization cost is appended to the sum and
    # summarized first.
    reg_terms = [] if loss_reg is None else [loss_reg]
    total_loss = tf.add_n(loss_tensors + reg_terms, name=TOTAL_LOSS_NAME)
    add_moving_summary(*(reg_terms + [total_loss] + loss_tensors))

    # NOTE(review): the tower condition presumably avoids rebuilding metrics on
    # every training tower; confirm against the tower-context semantics.
    if metrics and (ctx.is_main_training_tower or not ctx.is_training):
        # for list: one metric for each output, matched by position.
        # Fail with a readable message instead of an IndexError below.
        assert len(metrics) <= len(outputs), \
            "len({}) > len({}): every metric needs a matching output".format(
                str(metrics), str(outputs))
        metric_tensors = []
        for oid, metric_name in enumerate(metrics):
            output_tensor = outputs[oid]
            target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
            with cached_name_scope('keras_metric', top_level=False):
                metric_fn = keras.metrics.get(metric_name)
                metric_tensor = metric_fn(target_tensor, output_tensor)
            metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
            # check name conflict here
            _check_name(metric_tensor, metric_name)
            metric_tensors.append(metric_tensor)
        add_moving_summary(*metric_tensors)

    return total_loss
| 214 | |
| 215 | trainer.setup_graph( |
| 216 | input_signature + target_signature, |
nothing calls this directly
no test coverage detected