* Create a function which, when invoked with an array of `tf.Tensor`s as a batch of inputs, returns the prespecified loss and metrics of the model under the batch of input data.
()
| 1387 | * under the batch of input data. |
| 1388 | */ |
private makeTestFunction() {
  // Installs `this.testFunction`. The installed closure receives one flat
  // array holding the input tensors followed by the target tensors and
  // returns, in order: the running loss total after each loss function,
  // then the mean value of every entry in `this.metricsTensors`.
  this.testFunction = (data: Tensor[]) => {
    return tfc.tidy(() => {
      const numInputs = this.inputs.length;
      const numOutputs = this.outputs.length;

      // Split the flat batch: inputs come first, targets right after them.
      const inputTensors = data.slice(0, numInputs);
      const targetTensors = data.slice(numInputs, numInputs + numOutputs);

      // Pair each of the model's inputs with the concrete value fed to it.
      const feedEntries = [];
      for (let inputIdx = 0; inputIdx < numInputs; ++inputIdx) {
        feedEntries.push(
            {key: this.inputs[inputIdx], value: inputTensors[inputIdx]});
      }
      const predictions =
          execute(this.outputs, new FeedDict(feedEntries)) as Tensor[];

      const results: Scalar[] = [];

      // Losses: average each loss function's result and add it to the
      // running total. The running total (not the individual loss) is what
      // gets appended after every loss function.
      let runningLoss: Scalar;
      for (let lossIdx = 0; lossIdx < this.lossFunctions.length; ++lossIdx) {
        const computeLoss = this.lossFunctions[lossIdx];
        // TODO(cais): Add sample weighting and replace the simple
        //   averaging.
        const meanLoss: Scalar = tfc.mean(
            computeLoss(targetTensors[lossIdx], predictions[lossIdx]));
        if (lossIdx > 0) {
          runningLoss = tfc.add(runningLoss, meanLoss);
        } else {
          // The first loss seeds the running total.
          runningLoss = meanLoss;
        }
        results.push(runningLoss);
      }

      // Metrics: each entry holds the metric function at index 0 and the
      // index of the output it is evaluated on at index 1.
      for (let metricIdx = 0; metricIdx < this.metricsTensors.length;
           ++metricIdx) {
        const metricEntry = this.metricsTensors[metricIdx];
        const computeMetric = metricEntry[0];
        const targetIdx = metricEntry[1];
        // TODO(cais): Replace K.mean() with a proper weighting function.
        const averagedMetric = tfc.mean(
            computeMetric(targetTensors[targetIdx], predictions[targetIdx]));
        results.push(averagedMetric as Scalar);
      }

      return results;
    });
  };
}
| 1429 | |
| 1430 | /** |
| 1431 | * Trains the model for a fixed number of epochs (iterations on a |