MCPcopy
hub / github.com/tensorflow/tfjs / makeTrainFunction

Method makeTrainFunction

tfjs-layers/src/engine/training.ts:1298–1382  ·  view source on GitHub ↗

Creates a function that performs the following actions: (1) computes the losses; (2) sums them to get the total loss; (3) calls the optimizer, which computes the gradients of the LayersModel's trainable weights with respect to the total loss and updates the variables; (4) calculates the metrics; …

()

Source from the content-addressed store, hash-verified

1296 * 5. returns the values of the losses and metrics.
1297 */
1298 protected makeTrainFunction(): (data: Tensor[]) => Scalar[] {
1299 return (data: Tensor[]) => {
1300 const lossValues: Scalar[] = [];
1301
1302 const inputs = data.slice(0, this.inputs.length);
1303 const targets = data.slice(
1304 this.inputs.length, this.inputs.length + this.outputs.length);
1305 const sampleWeights = data.slice(
1306 this.inputs.length + this.outputs.length,
1307 this.inputs.length + this.outputs.length * 2);
1308
1309 const metricsValues: Scalar[] = [];
1310
1311 // Create a function that computes the total loss based on the
1312 // inputs. This function is used for obtaining gradients through
1313 // backprop.
1314 const totalLossFunction = () => {
1315 const feeds = [];
1316 for (let i = 0; i < this.inputs.length; ++i) {
1317 feeds.push({key: this.inputs[i], value: inputs[i]});
1318 }
1319 const feedDict = new FeedDict(feeds);
1320 const outputs =
1321 execute(this.outputs, feedDict, {'training': true}) as Tensor[];
1322 // TODO(cais): Take care of the case of multiple outputs from a
1323 // single layer?
1324
1325 let totalLoss: Tensor;
1326 for (let i = 0; i < this.lossFunctions.length; ++i) {
1327 const lossFunction = this.lossFunctions[i];
1328 let loss = lossFunction(targets[i], outputs[i]);
1329 if (sampleWeights[i] != null) {
1330 loss = computeWeightedLoss(loss, sampleWeights[i]);
1331 }
1332
1333 // TODO(cais): push Scalar instead.
1334 const meanLoss: Scalar = tfc.mean(loss);
1335 // TODO(cais): Use a scope() instead, to avoid ownership.
1336 lossValues.push(meanLoss);
1337 if (i === 0) {
1338 totalLoss = loss;
1339 } else {
1340 totalLoss = tfc.add(totalLoss, loss);
1341 }
1342 }
1343
1344 // Compute the metrics.
1345 // TODO(cais): These should probably be calculated outside
1346 // totalLossFunction to benefit speed?
1347 for (let i = 0; i < this.metricsTensors.length; ++i) {
1348 let weightedMetric: Scalar;
1349
1350 if (this.outputs.length > 1 && i < this.outputs.length) {
1351 weightedMetric = lossValues[i];
1352 } else {
1353 const metric = this.metricsTensors[i][0];
1354 const outputIndex = this.metricsTensors[i][1];
1355 weightedMetric =

Callers 3

fitMethod · 0.95
trainOnBatchMethod · 0.95
fitDatasetFunction · 0.80

Calls 3

sliceMethod · 0.65
readMethod · 0.65
concatMethod · 0.65

Tested by

no test coverage detected