MCPcopy
hub / github.com/tensorflow/tfjs / makeTestFunction

Method makeTestFunction

tfjs-layers/src/engine/training.ts:1389–1428  ·  view source on GitHub ↗

Create a function which, when invoked with an array of `tf.Tensor`s as a batch of inputs, returns the prespecified loss and metrics of the model under the batch of input data.

()

Source from the content-addressed store, hash-verified

1387 * under the batch of input data.
1388 */
private makeTestFunction() {
  this.testFunction = (data: Tensor[]) => {
    // `data` is the concatenation [...inputs, ...targets]: the first
    // `this.inputs.length` tensors are fed to the model's inputs and the next
    // `this.outputs.length` tensors are the targets for the model's outputs.
    //
    // Returns (inside a `tfc.tidy` scope, so only the returned tensors
    // survive):
    //   [totalLoss, ...metricValues]
    // where `totalLoss` is the sum of the per-output mean losses and is
    // present only when at least one loss function is configured, and
    // `metricValues` holds one mean value per entry in `this.metricsTensors`.
    // NOTE(review): this layout is assumed to line up with the model's metric
    // names (one leading "loss" label) — confirm against compile().
    return tfc.tidy(() => {
      const valOutputs: Scalar[] = [];
      const numInputs = this.inputs.length;
      const inputs = data.slice(0, numInputs);
      const targets =
          data.slice(numInputs, numInputs + this.outputs.length);

      const feeds = [];
      for (let i = 0; i < numInputs; ++i) {
        feeds.push({key: this.inputs[i], value: inputs[i]});
      }
      const feedDict = new FeedDict(feeds);
      const outputs = execute(this.outputs, feedDict) as Tensor[];

      // Compute the mean loss of each output.
      const perOutputLosses: Scalar[] = [];
      for (let i = 0; i < this.lossFunctions.length; ++i) {
        const lossFunction = this.lossFunctions[i];
        // TODO(cais): Add sample weighting and replace the simple
        // averaging.
        const loss: Scalar = tfc.mean(lossFunction(targets[i], outputs[i]));
        perOutputLosses.push(loss);
      }

      // Sum the per-output losses and report the total exactly once.
      // Pushing inside the loop above would emit running partial sums for
      // multi-output models and misalign every metric that follows.
      if (perOutputLosses.length > 0) {
        let totalLoss = perOutputLosses[0];
        for (let i = 1; i < perOutputLosses.length; ++i) {
          totalLoss = tfc.add(totalLoss, perOutputLosses[i]);
        }
        valOutputs.push(totalLoss);
      }

      // Compute the metrics; each entry pairs a metric function with the
      // index of the output (and target) it is evaluated on.
      for (let i = 0; i < this.metricsTensors.length; ++i) {
        const metric = this.metricsTensors[i][0];
        const outputIndex = this.metricsTensors[i][1];
        // TODO(cais): Replace K.mean() with a proper weighting function.
        const meanMetric =
            tfc.mean(metric(targets[outputIndex], outputs[outputIndex]));
        valOutputs.push(meanMetric as Scalar);
      }
      return valOutputs;
    });
  };
}
1429
1430 /**
1431 * Trains the model for a fixed number of epochs (iterations on a

Callers 3

evaluateMethod · 0.95
evaluateDatasetMethod · 0.95
fitMethod · 0.95

Calls 6

executeFunction · 0.90
tidyMethod · 0.80
meanMethod · 0.80
sliceMethod · 0.65
addMethod · 0.65
pushMethod · 0.45

Tested by

no test coverage detected