MCPcopy
hub / github.com/tensorflow/tfjs / fitLoop

Method fitLoop

tfjs-layers/src/engine/training.ts:1631–1766  ·  view source on GitHub ↗

Abstract fit function for `f(ins)`.
@param f A function returning a list of tensors. For training, this function is expected to perform the updates to the variables.
@param ins List of tensors to be fed to `f`.
@param outLabels List of strings, display names of the outputs of `f`.

(
      f: (data: Tensor[]) => Scalar[], ins: Tensor[], outLabels?:
      string[], batchSize?: number, epochs?: number, verbose?: number,
      callbacks?: BaseCallback[], valF?: (data: Tensor[]) => Scalar[], valIns?:
      Tensor[], shuffle?: boolean|string, callbackMetrics?: string[],
      initialEpoch?: number, stepsPerEpoch?: number, validationSteps?: number)

Source from the content-addressed store, hash-verified

1629 * @returns A `History` object.
1630 */
1631 async fitLoop(
1632 f: (data: Tensor[]) => Scalar[], ins: Tensor[], outLabels?:
1633 string[], batchSize?: number, epochs?: number, verbose?: number,
1634 callbacks?: BaseCallback[], valF?: (data: Tensor[]) => Scalar[], valIns?:
1635 Tensor[], shuffle?: boolean|string, callbackMetrics?: string[],
1636 initialEpoch?: number, stepsPerEpoch?: number, validationSteps?: number):
1637 Promise<History> {
1638 if (batchSize == null) {
1639 batchSize = 32;
1640 }
1641 if (epochs == null) {
1642 epochs = 1;
1643 }
1644 if (shuffle == null) {
1645 shuffle = true;
1646 }
1647 if (initialEpoch == null) {
1648 initialEpoch = 0;
1649 }
1650
1651 // TODO(cais): Change const to let below when implementing validation.
1652 let doValidation = false;
1653 if (valF != null && valIns != null) {
1654 doValidation = true;
1655 // TODO(cais): verbose message.
1656 }
1657 if (validationSteps != null) {
1658 doValidation = true;
1659 if (stepsPerEpoch == null) {
1660 throw new ValueError(
1661 'Can only use `validationSteps` when doing step-wise training, ' +
1662 'i.e., `stepsPerEpoch` must be set.');
1663 }
1664 }
1665
1666 const numTrainSamples =
1667 this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch');
1668 let indexArray: number[];
1669 if (numTrainSamples != null) {
1670 indexArray = range(0, numTrainSamples);
1671 }
1672
1673 if (verbose == null) {
1674 verbose = 1;
1675 }
1676
1677 const {callbackList, history} = configureCallbacks(
1678 callbacks, verbose, epochs, initialEpoch, numTrainSamples,
1679 stepsPerEpoch, batchSize, doValidation, callbackMetrics);
1680 callbackList.setModel(this);
1681 this.history = history;
1682 await callbackList.onTrainBegin();
1683 this.stopTraining_ = false;
1684 // TODO(cais): Take care of callbacks.validation_data as in PyKeras.
1685 // TODO(cais): Pre-convert feeds for performance as in PyKeras.
1686
1687 for (let epoch = initialEpoch; epoch < epochs; ++epoch) {
1688 await callbackList.onEpochBegin(epoch);

Callers 1

fitMethod · 0.95

Calls 15

checkNumSamplesMethod · 0.95
testLoopMethod · 0.95
rangeFunction · 0.90
configureCallbacksFunction · 0.90
tensor1dFunction · 0.90
makeBatchesFunction · 0.90
sliceArraysByIndicesFunction · 0.90
disposeTensorsInLogsFunction · 0.90
tidyMethod · 0.80
keepMethod · 0.80
syncDataMethod · 0.80
fFunction · 0.50

Tested by

no test coverage detected