* Runs a single gradient update on a single batch of data. * * This method differs from `fit()` and `fitDataset()` in the following * regards: * - It operates on exactly one batch of data. * - It returns only the loss and metric values, instead of * returning the batch-by-b
(
x: Tensor|Tensor[]|{[inputName: string]: Tensor},
y: Tensor|Tensor[]|
{[inputName: string]: Tensor})
| 1817 | * @doc {heading: 'Models', subheading: 'Classes'} |
| 1818 | */ |
async trainOnBatch(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
    y: Tensor|Tensor[]|
    {[inputName: string]: Tensor}): Promise<number|number[]> {
  // TODO(cais): Support sampleWeight and classWeight.
  // TODO(cais): Support Dataset objects.
  const standardizeOut = await this.standardizeUserData(x, y);
  const inputs = standardizeOut[0];
  const targets = standardizeOut[1];
  try {
    const trainFunction = this.makeTrainFunction();
    // One gradient step over the concatenated [inputs..., targets...] list;
    // the result is the list of loss tensors whose values are returned.
    const losses = trainFunction(inputs.concat(targets));
    try {
      const lossValues: number[] = [];
      for (const loss of losses) {
        // Only the first element of each loss tensor's data is read.
        const v = await loss.data();
        lossValues.push(v[0]);
      }
      // Evaluated before the `finally` blocks run; it only touches plain
      // numbers, so disposing tensors afterwards cannot affect the result.
      return singletonOrArray(lossValues);
    } finally {
      // Release the loss tensors even if reading their data rejects.
      tfc.dispose(losses);
    }
  } finally {
    // Runs on both success and failure so standardization does not leak
    // tensors when training or data readback throws.
    // NOTE(review): presumably disposeNewTensors frees only tensors that
    // standardizeUserData created (i.e. not the caller's own x / y) —
    // confirm against its definition.
    disposeNewTensors(inputs, x);
    disposeNewTensors(targets, y);
  }
}
| 1840 | |
| 1841 | /** |
| 1842 | * Extract weight values of the model. |
no test coverage detected