MCPcopy
hub / github.com/tensorflow/tfjs / totalLossFunction

Method totalLossFunction

tfjs-layers/src/engine/training.ts:1314–1372  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

1312 // inputs. This function is used for obtaining gradients through
1313 // backprop.
1314 const totalLossFunction = () => {
1315 const feeds = [];
1316 for (let i = 0; i < this.inputs.length; ++i) {
1317 feeds.push({key: this.inputs[i], value: inputs[i]});
1318 }
1319 const feedDict = new FeedDict(feeds);
1320 const outputs =
1321 execute(this.outputs, feedDict, {'training': true}) as Tensor[];
1322 // TODO(cais): Take care of the case of multiple outputs from a
1323 // single layer?
1324
1325 let totalLoss: Tensor;
1326 for (let i = 0; i < this.lossFunctions.length; ++i) {
1327 const lossFunction = this.lossFunctions[i];
1328 let loss = lossFunction(targets[i], outputs[i]);
1329 if (sampleWeights[i] != null) {
1330 loss = computeWeightedLoss(loss, sampleWeights[i]);
1331 }
1332
1333 // TODO(cais): push Scalar instead.
1334 const meanLoss: Scalar = tfc.mean(loss);
1335 // TODO(cais): Use a scope() instead, to avoid ownership.
1336 lossValues.push(meanLoss);
1337 if (i === 0) {
1338 totalLoss = loss;
1339 } else {
1340 totalLoss = tfc.add(totalLoss, loss);
1341 }
1342 }
1343
1344 // Compute the metrics.
1345 // TODO(cais): These should probably be calculated outside
1346 // totalLossFunction to benefit speed?
1347 for (let i = 0; i < this.metricsTensors.length; ++i) {
1348 let weightedMetric: Scalar;
1349
1350 if (this.outputs.length > 1 && i < this.outputs.length) {
1351 weightedMetric = lossValues[i];
1352 } else {
1353 const metric = this.metricsTensors[i][0];
1354 const outputIndex = this.metricsTensors[i][1];
1355 weightedMetric =
1356 tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
1357 }
1358
1359 tfc.keep(weightedMetric);
1360 // TODO(cais): Use a scope() instead, to avoid ownership.
1361 metricsValues.push(weightedMetric);
1362 }
1363
1364 totalLoss = tfc.mean(totalLoss);
1365
1366 // Add regularizer penalties.
1367 this.calculateLosses().forEach(regularizerLoss => {
1368 totalLoss = tfc.add(totalLoss, regularizerLoss);
1369 });
1370
1371 return totalLoss as Scalar;

Callers

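No direct callers detected; per the source comment, this closure is used for obtaining gradients through backprop rather than being called by name.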

Calls 6

execute (function) · 0.90
computeWeightedLoss (function) · 0.90
mean (method) · 0.80
keep (method) · 0.80
add (method) · 0.65
push (method) · 0.45

Tested by

no test coverage detected