MCPcopy
hub / github.com/tensorflow/tfjs / fit

Method fit

tfjs-layers/src/engine/training.ts:1464–1602  ·  view source on GitHub ↗

Trains the model for a fixed number of epochs (iterations on a dataset).

```js
const model = tf.sequential({
  layers: [tf.layers.dense({units: 1, inputShape: [10]})]
});
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
for (let i = 1; i < 5 ; …
```
(doc comment truncated in this summary view)

(
      x: Tensor|Tensor[]|{[inputName: string]: Tensor},
      y: Tensor|Tensor[]|{[inputName: string]: Tensor},
      args: ModelFitArgs = {})

Source from the content-addressed store, hash-verified

1462 * @doc {heading: 'Models', subheading: 'Classes'}
1463 */
1464 async fit(
1465 x: Tensor|Tensor[]|{[inputName: string]: Tensor},
1466 y: Tensor|Tensor[]|{[inputName: string]: Tensor},
1467 args: ModelFitArgs = {}): Promise<History> {
1468 if (this.isTraining) {
1469 throw new Error(
1470 'Cannot start training because another fit() call is ongoing.');
1471 }
1472 this.isTraining = true;
1473 let inputs: Tensor[];
1474 let targets: Tensor[];
1475 let originalInputs: Tensor[];
1476 let originalTargets: Tensor[];
1477 let inputValX: Tensor|Tensor[];
1478 let inputValY: Tensor|Tensor[];
1479 let valX: Tensor|Tensor[];
1480 let valY: Tensor|Tensor[];
1481 let sampleWeights: Tensor[];
1482 try {
1483 const batchSize = args.batchSize == null ? 32 : args.batchSize;
1484 checkBatchSize(batchSize);
1485
1486 // Validate user data.
1487 // TODO(cais): Support sampleWeight.
1488 const checkBatchAxis = false;
1489 const standardizedOuts =
1490 await this.standardizeUserData(
1491 x, y, args.sampleWeight, args.classWeight, checkBatchAxis,
1492 batchSize) as [Tensor[], Tensor[], Tensor[]];
1493 inputs = standardizedOuts[0];
1494 targets = standardizedOuts[1];
1495 sampleWeights = standardizedOuts[2];
1496
1497 // Prepare validation data.
1498 let doValidation = false;
1499 let valIns: Tensor[];
1500 if (args.validationData != null && args.validationData.length > 0) {
1501 doValidation = true;
1502 if (args.validationData.length === 2) {
1503 // config.validationData consists of valX and valY.
1504 inputValX = args.validationData[0];
1505 inputValY = args.validationData[1];
1506 } else if (args.validationData.length === 3) {
1507 throw new NotImplementedError(
1508 'validationData including sample weights is not supported yet.');
1509 } else {
1510 throw new ValueError(
1511 `When passing validation data, it must contain 2 (valX, valY) ` +
1512 `or 3 (valX, valY, valSampleWeight) items; ` +
1513 `${args.validationData} is invalid.`);
1514 }
1515
1516 const checkBatchAxis = true;
1517 const valStandardized =
1518 await this.standardizeUserData(
1519 inputValX, inputValY, null, /** Unused sample weights. */
1520 null, /** Unused class weights. */
1521 checkBatchAxis, batchSize) as [Tensor[], Tensor[], Tensor[]];

Callers 3

training_test.ts · File · 0.45
container_test.ts · File · 0.45

Calls 14

standardizeUserData · Method · 0.95
makeTrainFunction · Method · 0.95
makeTestFunction · Method · 0.95
fitLoop · Method · 0.95
checkBatchSize · Function · 0.90
sliceArrays · Function · 0.90
standardizeCallbacks · Function · 0.90
disposeNewTensors · Function · 0.90
floor · Method · 0.80
concat · Method · 0.65

Tested by

no test coverage detected