* Creates a function that performs the following actions: * * 1. computes the losses * 2. sums them to get the total loss * 3. calls the optimizer, which computes the gradients of the LayersModel's * trainable weights w.r.t. the total loss and updates the variables * 4. calculates the metrics
()
| 1296 | * 5. returns the values of the losses and metrics. |
| 1297 | */ |
| 1298 | protected makeTrainFunction(): (data: Tensor[]) => Scalar[] { |
| 1299 | return (data: Tensor[]) => { |
| 1300 | const lossValues: Scalar[] = []; |
| 1301 | |
| 1302 | const inputs = data.slice(0, this.inputs.length); |
| 1303 | const targets = data.slice( |
| 1304 | this.inputs.length, this.inputs.length + this.outputs.length); |
| 1305 | const sampleWeights = data.slice( |
| 1306 | this.inputs.length + this.outputs.length, |
| 1307 | this.inputs.length + this.outputs.length * 2); |
| 1308 | |
| 1309 | const metricsValues: Scalar[] = []; |
| 1310 | |
| 1311 | // Create a function that computes the total loss based on the |
| 1312 | // inputs. This function is used for obtaining gradients through |
| 1313 | // backprop. |
| 1314 | const totalLossFunction = () => { |
| 1315 | const feeds = []; |
| 1316 | for (let i = 0; i < this.inputs.length; ++i) { |
| 1317 | feeds.push({key: this.inputs[i], value: inputs[i]}); |
| 1318 | } |
| 1319 | const feedDict = new FeedDict(feeds); |
| 1320 | const outputs = |
| 1321 | execute(this.outputs, feedDict, {'training': true}) as Tensor[]; |
| 1322 | // TODO(cais): Take care of the case of multiple outputs from a |
| 1323 | // single layer? |
| 1324 | |
| 1325 | let totalLoss: Tensor; |
| 1326 | for (let i = 0; i < this.lossFunctions.length; ++i) { |
| 1327 | const lossFunction = this.lossFunctions[i]; |
| 1328 | let loss = lossFunction(targets[i], outputs[i]); |
| 1329 | if (sampleWeights[i] != null) { |
| 1330 | loss = computeWeightedLoss(loss, sampleWeights[i]); |
| 1331 | } |
| 1332 | |
| 1333 | // TODO(cais): push Scalar instead. |
| 1334 | const meanLoss: Scalar = tfc.mean(loss); |
| 1335 | // TODO(cais): Use a scope() instead, to avoid ownership. |
| 1336 | lossValues.push(meanLoss); |
| 1337 | if (i === 0) { |
| 1338 | totalLoss = loss; |
| 1339 | } else { |
| 1340 | totalLoss = tfc.add(totalLoss, loss); |
| 1341 | } |
| 1342 | } |
| 1343 | |
| 1344 | // Compute the metrics. |
| 1345 | // TODO(cais): These should probably be calculated outside |
| 1346 | // totalLossFunction to benefit speed? |
| 1347 | for (let i = 0; i < this.metricsTensors.length; ++i) { |
| 1348 | let weightedMetric: Scalar; |
| 1349 | |
| 1350 | if (this.outputs.length > 1 && i < this.outputs.length) { |
| 1351 | weightedMetric = lossValues[i]; |
| 1352 | } else { |
| 1353 | const metric = this.metricsTensors[i][0]; |
| 1354 | const outputIndex = this.metricsTensors[i][1]; |
| 1355 | weightedMetric = |
no test coverage detected