| 1312 | // inputs. This function is used for obtaining gradients through |
| 1313 | // backprop. |
| 1314 | const totalLossFunction = () => { |
| 1315 | const feeds = []; |
| 1316 | for (let i = 0; i < this.inputs.length; ++i) { |
| 1317 | feeds.push({key: this.inputs[i], value: inputs[i]}); |
| 1318 | } |
| 1319 | const feedDict = new FeedDict(feeds); |
| 1320 | const outputs = |
| 1321 | execute(this.outputs, feedDict, {'training': true}) as Tensor[]; |
| 1322 | // TODO(cais): Take care of the case of multiple outputs from a |
| 1323 | // single layer? |
| 1324 | |
| 1325 | let totalLoss: Tensor; |
| 1326 | for (let i = 0; i < this.lossFunctions.length; ++i) { |
| 1327 | const lossFunction = this.lossFunctions[i]; |
| 1328 | let loss = lossFunction(targets[i], outputs[i]); |
| 1329 | if (sampleWeights[i] != null) { |
| 1330 | loss = computeWeightedLoss(loss, sampleWeights[i]); |
| 1331 | } |
| 1332 | |
| 1333 | // TODO(cais): push Scalar instead. |
| 1334 | const meanLoss: Scalar = tfc.mean(loss); |
| 1335 | // TODO(cais): Use a scope() instead, to avoid ownership. |
| 1336 | lossValues.push(meanLoss); |
| 1337 | if (i === 0) { |
| 1338 | totalLoss = loss; |
| 1339 | } else { |
| 1340 | totalLoss = tfc.add(totalLoss, loss); |
| 1341 | } |
| 1342 | } |
| 1343 | |
| 1344 | // Compute the metrics. |
| 1345 | // TODO(cais): These should probably be calculated outside |
| 1346 | // totalLossFunction to benefit speed? |
| 1347 | for (let i = 0; i < this.metricsTensors.length; ++i) { |
| 1348 | let weightedMetric: Scalar; |
| 1349 | |
| 1350 | if (this.outputs.length > 1 && i < this.outputs.length) { |
| 1351 | weightedMetric = lossValues[i]; |
| 1352 | } else { |
| 1353 | const metric = this.metricsTensors[i][0]; |
| 1354 | const outputIndex = this.metricsTensors[i][1]; |
| 1355 | weightedMetric = |
| 1356 | tfc.mean(metric(targets[outputIndex], outputs[outputIndex])); |
| 1357 | } |
| 1358 | |
| 1359 | tfc.keep(weightedMetric); |
| 1360 | // TODO(cais): Use a scope() instead, to avoid ownership. |
| 1361 | metricsValues.push(weightedMetric); |
| 1362 | } |
| 1363 | |
| 1364 | totalLoss = tfc.mean(totalLoss); |
| 1365 | |
| 1366 | // Add regularizer penalties. |
| 1367 | this.calculateLosses().forEach(regularizerLoss => { |
| 1368 | totalLoss = tfc.add(totalLoss, regularizerLoss); |
| 1369 | }); |
| 1370 | |
| 1371 | return totalLoss as Scalar; |