(
// Type `model` as `any` here to avoid circular dependency w/
// training.ts.
// tslint:disable-next-line:no-any
model: any, dataset: Dataset<T>,
args: ModelFitDatasetArgs<T>)
| 299 | } |
| 300 | |
| 301 | export async function fitDataset<T>( |
| 302 | // Type `model` as `any` here to avoid circular dependency w/ |
| 303 | // training.ts. |
| 304 | // tslint:disable-next-line:no-any |
| 305 | model: any, dataset: Dataset<T>, |
| 306 | args: ModelFitDatasetArgs<T>): Promise<History> { |
| 307 | const hasBatchesPerEpoch = args.batchesPerEpoch != null; |
| 308 | tfc.util.assert( |
| 309 | model.optimizer != null, |
| 310 | () => 'You must compile a model before training/testing. Use ' + |
| 311 | 'LayersModel.compile(modelCompileConfig).'); |
| 312 | |
| 313 | tfc.util.assert( |
| 314 | args != null, |
| 315 | () => `For fitDataset(), the 2nd argument (config) is required, ` + |
| 316 | `but it is not provided in this call.`); |
| 317 | tfc.util.assert( |
| 318 | args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs), |
| 319 | () => `For fitDataset(), config.epochs is expected to be a positive ` + |
| 320 | `integer, but got ${args.epochs}`); |
| 321 | tfc.util.assert( |
| 322 | !hasBatchesPerEpoch || |
| 323 | (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)), |
| 324 | () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` + |
| 325 | `positive integer if specified, but got ${args.batchesPerEpoch}`); |
| 326 | tfc.util.assert( |
| 327 | // tslint:disable-next-line:no-any |
| 328 | (args as any)['validationSplit'] == null, |
| 329 | () => '`validationSplit` is not supported by `fitDataset()`. ' + |
| 330 | 'Use validationData instead.'); |
| 331 | |
| 332 | if (model.isTraining) { |
| 333 | throw new Error( |
| 334 | 'Cannot start training because another fit() call is ongoing.'); |
| 335 | } |
| 336 | model.isTraining = true; |
| 337 | |
| 338 | try { |
| 339 | const doValidation = args.validationData != null; |
| 340 | let valXs: tfc.Tensor|tfc.Tensor[]; |
| 341 | let valYs: tfc.Tensor|tfc.Tensor[]; |
| 342 | if (doValidation) { |
| 343 | if (isDatasetObject(args.validationData)) { |
| 344 | tfc.util.assert( |
| 345 | args.validationBatches == null || |
| 346 | (args.validationBatches > 0 && |
| 347 | Number.isInteger(args.validationBatches)), |
| 348 | () => `For fitDataset() with dataset-based validation, ` + |
| 349 | `config.validationBatches is expected not to be provided, ` + |
| 350 | `or to be a positive integer, ` + |
| 351 | `but got ${args.validationBatches}`); |
| 352 | } else { |
| 353 | const validationData = standardizeTensorValidationData( |
| 354 | args.validationData as |
| 355 | [tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[]] | |
| 356 | [ |
| 357 | tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[], |
| 358 | tfc.Tensor | tfc.Tensor[] |
no test coverage detected
searching dependent graphs…