MCPcopy
hub / github.com/tensorflow/tfjs / fitDataset

Function fitDataset

tfjs-layers/src/engine/training_dataset.ts:301–500  ·  view source on GitHub ↗
(
    // Type `model` as `any` here to avoid circular dependency w/
    // training.ts.
    // tslint:disable-next-line:no-any
    model: any, dataset: Dataset<T>,
    args: ModelFitDatasetArgs<T>)

Source from the content-addressed store, hash-verified

299}
300
301export async function fitDataset<T>(
302 // Type `model` as `any` here to avoid circular dependency w/
303 // training.ts.
304 // tslint:disable-next-line:no-any
305 model: any, dataset: Dataset<T>,
306 args: ModelFitDatasetArgs<T>): Promise<History> {
307 const hasBatchesPerEpoch = args.batchesPerEpoch != null;
308 tfc.util.assert(
309 model.optimizer != null,
310 () => 'You must compile a model before training/testing. Use ' +
311 'LayersModel.compile(modelCompileConfig).');
312
313 tfc.util.assert(
314 args != null,
315 () => `For fitDataset(), the 2nd argument (config) is required, ` +
316 `but it is not provided in this call.`);
317 tfc.util.assert(
318 args.epochs != null && args.epochs > 0 && Number.isInteger(args.epochs),
319 () => `For fitDataset(), config.epochs is expected to be a positive ` +
320 `integer, but got ${args.epochs}`);
321 tfc.util.assert(
322 !hasBatchesPerEpoch ||
323 (args.batchesPerEpoch > 0 && Number.isInteger(args.batchesPerEpoch)),
324 () => `For fitDataset(), config.batchesPerEpoch is expected to be a ` +
325 `positive integer if specified, but got ${args.batchesPerEpoch}`);
326 tfc.util.assert(
327 // tslint:disable-next-line:no-any
328 (args as any)['validationSplit'] == null,
329 () => '`validationSplit` is not supported by `fitDataset()`. ' +
330 'Use validationData instead.');
331
332 if (model.isTraining) {
333 throw new Error(
334 'Cannot start training because another fit() call is ongoing.');
335 }
336 model.isTraining = true;
337
338 try {
339 const doValidation = args.validationData != null;
340 let valXs: tfc.Tensor|tfc.Tensor[];
341 let valYs: tfc.Tensor|tfc.Tensor[];
342 if (doValidation) {
343 if (isDatasetObject(args.validationData)) {
344 tfc.util.assert(
345 args.validationBatches == null ||
346 (args.validationBatches > 0 &&
347 Number.isInteger(args.validationBatches)),
348 () => `For fitDataset() with dataset-based validation, ` +
349 `config.validationBatches is expected not to be provided, ` +
350 `or to be a positive integer, ` +
351 `but got ${args.validationBatches}`);
352 } else {
353 const validationData = standardizeTensorValidationData(
354 args.validationData as
355 [tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[]] |
356 [
357 tfc.Tensor | tfc.Tensor[], tfc.Tensor | tfc.Tensor[],
358 tfc.Tensor | tfc.Tensor[]

Callers 1

fitDatasetMethod · 0.90

Calls 15

standardizeCallbacksFunction · 0.90
configureCallbacksFunction · 0.90
standardizeClassWeightsFunction · 0.90
standardizeWeightsFunction · 0.90
disposeTensorsInLogsFunction · 0.90
toListFunction · 0.90
isDatasetObjectFunction · 0.85
getStepsPerEpochFunction · 0.85
makeTrainFunctionMethod · 0.80

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…