* Abstract fit function for `f(ins)`. * @param f A Function returning a list of tensors. For training, this * function is expected to perform the updates to the variables. * @param ins List of tensors to be fed to `f`. * @param outLabels List of strings, display names of the outputs of
(
f: (data: Tensor[]) => Scalar[], ins: Tensor[], outLabels?:
string[], batchSize?: number, epochs?: number, verbose?: number,
callbacks?: BaseCallback[], valF?: (data: Tensor[]) => Scalar[], valIns?:
Tensor[], shuffle?: boolean|string, callbackMetrics?: string[],
initialEpoch?: number, stepsPerEpoch?: number, validationSteps?: number)
| 1629 | * @returns A `History` object. |
| 1630 | */ |
| 1631 | async fitLoop( |
| 1632 | f: (data: Tensor[]) => Scalar[], ins: Tensor[], outLabels?: |
| 1633 | string[], batchSize?: number, epochs?: number, verbose?: number, |
| 1634 | callbacks?: BaseCallback[], valF?: (data: Tensor[]) => Scalar[], valIns?: |
| 1635 | Tensor[], shuffle?: boolean|string, callbackMetrics?: string[], |
| 1636 | initialEpoch?: number, stepsPerEpoch?: number, validationSteps?: number): |
| 1637 | Promise<History> { |
| 1638 | if (batchSize == null) { |
| 1639 | batchSize = 32; |
| 1640 | } |
| 1641 | if (epochs == null) { |
| 1642 | epochs = 1; |
| 1643 | } |
| 1644 | if (shuffle == null) { |
| 1645 | shuffle = true; |
| 1646 | } |
| 1647 | if (initialEpoch == null) { |
| 1648 | initialEpoch = 0; |
| 1649 | } |
| 1650 | |
| 1651 | // TODO(cais): Change const to let below when implementing validation. |
| 1652 | let doValidation = false; |
| 1653 | if (valF != null && valIns != null) { |
| 1654 | doValidation = true; |
| 1655 | // TODO(cais): verbose message. |
| 1656 | } |
| 1657 | if (validationSteps != null) { |
| 1658 | doValidation = true; |
| 1659 | if (stepsPerEpoch == null) { |
| 1660 | throw new ValueError( |
| 1661 | 'Can only use `validationSteps` when doing step-wise training, ' + |
| 1662 | 'i.e., `stepsPerEpoch` must be set.'); |
| 1663 | } |
| 1664 | } |
| 1665 | |
| 1666 | const numTrainSamples = |
| 1667 | this.checkNumSamples(ins, batchSize, stepsPerEpoch, 'steps_per_epoch'); |
| 1668 | let indexArray: number[]; |
| 1669 | if (numTrainSamples != null) { |
| 1670 | indexArray = range(0, numTrainSamples); |
| 1671 | } |
| 1672 | |
| 1673 | if (verbose == null) { |
| 1674 | verbose = 1; |
| 1675 | } |
| 1676 | |
| 1677 | const {callbackList, history} = configureCallbacks( |
| 1678 | callbacks, verbose, epochs, initialEpoch, numTrainSamples, |
| 1679 | stepsPerEpoch, batchSize, doValidation, callbackMetrics); |
| 1680 | callbackList.setModel(this); |
| 1681 | this.history = history; |
| 1682 | await callbackList.onTrainBegin(); |
| 1683 | this.stopTraining_ = false; |
| 1684 | // TODO(cais): Take care of callbacks.validation_data as in PyKeras. |
| 1685 | // TODO(cais): Pre-convert feeds for performance as in PyKeras. |
| 1686 | |
| 1687 | for (let epoch = initialEpoch; epoch < epochs; ++epoch) { |
| 1688 | await callbackList.onEpochBegin(epoch); |
no test coverage detected