* Trains the model for a fixed number of epochs (iterations on a * dataset). * * ```js * const model = tf.sequential({ * layers: [tf.layers.dense({units: 1, inputShape: [10]})] * }); * model.compile({optimizer: 'sgd', loss: 'meanSquaredError'}); * for (let i = 1; i < 5 ;
(
x: Tensor|Tensor[]|{[inputName: string]: Tensor},
y: Tensor|Tensor[]|{[inputName: string]: Tensor},
args: ModelFitArgs = {})
| 1462 | * @doc {heading: 'Models', subheading: 'Classes'} |
| 1463 | */ |
| 1464 | async fit( |
| 1465 | x: Tensor|Tensor[]|{[inputName: string]: Tensor}, |
| 1466 | y: Tensor|Tensor[]|{[inputName: string]: Tensor}, |
| 1467 | args: ModelFitArgs = {}): Promise<History> { |
| 1468 | if (this.isTraining) { |
| 1469 | throw new Error( |
| 1470 | 'Cannot start training because another fit() call is ongoing.'); |
| 1471 | } |
| 1472 | this.isTraining = true; |
| 1473 | let inputs: Tensor[]; |
| 1474 | let targets: Tensor[]; |
| 1475 | let originalInputs: Tensor[]; |
| 1476 | let originalTargets: Tensor[]; |
| 1477 | let inputValX: Tensor|Tensor[]; |
| 1478 | let inputValY: Tensor|Tensor[]; |
| 1479 | let valX: Tensor|Tensor[]; |
| 1480 | let valY: Tensor|Tensor[]; |
| 1481 | let sampleWeights: Tensor[]; |
| 1482 | try { |
| 1483 | const batchSize = args.batchSize == null ? 32 : args.batchSize; |
| 1484 | checkBatchSize(batchSize); |
| 1485 | |
| 1486 | // Validate user data. |
| 1487 | // TODO(cais): Support sampleWeight. |
| 1488 | const checkBatchAxis = false; |
| 1489 | const standardizedOuts = |
| 1490 | await this.standardizeUserData( |
| 1491 | x, y, args.sampleWeight, args.classWeight, checkBatchAxis, |
| 1492 | batchSize) as [Tensor[], Tensor[], Tensor[]]; |
| 1493 | inputs = standardizedOuts[0]; |
| 1494 | targets = standardizedOuts[1]; |
| 1495 | sampleWeights = standardizedOuts[2]; |
| 1496 | |
| 1497 | // Prepare validation data. |
| 1498 | let doValidation = false; |
| 1499 | let valIns: Tensor[]; |
| 1500 | if (args.validationData != null && args.validationData.length > 0) { |
| 1501 | doValidation = true; |
| 1502 | if (args.validationData.length === 2) { |
| 1503 | // config.validationData consists of valX and valY. |
| 1504 | inputValX = args.validationData[0]; |
| 1505 | inputValY = args.validationData[1]; |
| 1506 | } else if (args.validationData.length === 3) { |
| 1507 | throw new NotImplementedError( |
| 1508 | 'validationData including sample weights is not supported yet.'); |
| 1509 | } else { |
| 1510 | throw new ValueError( |
| 1511 | `When passing validation data, it must contain 2 (valX, valY) ` + |
| 1512 | `or 3 (valX, valY, valSampleWeight) items; ` + |
| 1513 | `${args.validationData} is invalid.`); |
| 1514 | } |
| 1515 | |
| 1516 | const checkBatchAxis = true; |
| 1517 | const valStandardized = |
| 1518 | await this.standardizeUserData( |
| 1519 | inputValX, inputValY, null, /** Unused sample weights. */ |
| 1520 | null, /** Unused class weights. */ |
| 1521 | checkBatchAxis, batchSize) as [Tensor[], Tensor[], Tensor[]]; |
no test coverage detected