Repository: github.com/tensorflow/tfjs

Method trainOnBatch

tfjs-layers/src/engine/training.ts, lines 1819–1839

Runs a single gradient update on a single batch of data.

This method differs from `fit()` and `fitDataset()` in the following regards:

- It operates on exactly one batch of data.
- It returns only the loss and metric values, instead of returning the batch-by-b…
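Because `trainOnBatch` resolves to either a bare number or an array of numbers, a caller usually pairs the values with names before logging them. Below is a minimal sketch, assuming (as with Keras' `train_on_batch`) that the array lists the total loss first and then the compiled metrics, in the same order as a names list such as the model's `metricsNames`. `labelLossValues` is a hypothetical helper, not a tfjs API.

```typescript
// Hypothetical helper (not part of tfjs): pairs the value(s) returned by
// trainOnBatch with names. ASSUMPTION: `names` is in the same order as the
// returned values (e.g. ['loss', ...metric names]).
function labelLossValues(
    result: number|number[], names: string[]): Record<string, number> {
  // With a single loss and no metrics, trainOnBatch yields a bare number.
  const values = Array.isArray(result) ? result : [result];
  if (values.length !== names.length) {
    throw new Error(`Expected ${names.length} values, got ${values.length}`);
  }
  const labeled: Record<string, number> = {};
  names.forEach((name, i) => {
    const v = values[i];
    // The length check above guarantees v is defined; the guard keeps
    // stricter compiler settings happy.
    if (v !== undefined) {
      labeled[name] = v;
    }
  });
  return labeled;
}

// Usage with made-up numbers:
const logs = labelLossValues([0.42, 0.9], ['loss', 'acc']);
console.log(logs);  // { loss: 0.42, acc: 0.9 }
```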

trainOnBatch(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
    y: Tensor|Tensor[]|{[inputName: string]: Tensor}): Promise<number|number[]>
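Both `x` and `y` accept a single tensor, an array of tensors, or an object keyed by name. The first thing the method does is hand them to `standardizeUserData`, which presumably brings all three forms into a common ordered array. The sketch below illustrates only that ordering idea, using a stand-in `FakeTensor` type and a hypothetical `toOrderedList` helper; the real method does more validation than shown here and its code is not reproduced.

```typescript
// Minimal stand-in for tf.Tensor so the sketch runs without
// @tensorflow/tfjs; only an identifying label is modeled.
interface FakeTensor {
  label: string;
}

// The three shapes trainOnBatch accepts for `x` (and likewise `y`).
type TensorArg = FakeTensor|FakeTensor[]|{[inputName: string]: FakeTensor};

function isFakeTensor(v: unknown): v is FakeTensor {
  return typeof v === 'object' && v !== null &&
      typeof (v as FakeTensor).label === 'string';
}

// Hypothetical helper: turns any accepted shape into an array ordered like
// `names` (for example, the model's input names).
function toOrderedList(arg: TensorArg, names: string[]): FakeTensor[] {
  if (Array.isArray(arg)) {
    // Arrays are taken to be in model order already; only the count is checked.
    if (arg.length !== names.length) {
      throw new Error(`Expected ${names.length} tensors, got ${arg.length}`);
    }
    return arg;
  }
  if (isFakeTensor(arg)) {
    if (names.length !== 1) {
      throw new Error('A single tensor is only valid for one-input models');
    }
    return [arg];
  }
  // Remaining case: an object keyed by name; look each name up in order.
  const byName = arg;
  return names.map((name) => {
    const t = byName[name];
    if (t === undefined) {
      throw new Error(`Missing tensor for "${name}"`);
    }
    return t;
  });
}

// Usage: key order in the object does not matter, `names` decides.
const a: FakeTensor = {label: 'a'};
const b: FakeTensor = {label: 'b'};
console.log(toOrderedList({input2: b, input1: a}, ['input1', 'input2'])
                .map((t) => t.label));  // [ 'a', 'b' ]
```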


```typescript
 * @doc {heading: 'Models', subheading: 'Classes'}
 */
async trainOnBatch(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
    y: Tensor|Tensor[]|
    {[inputName: string]: Tensor}): Promise<number|number[]> {
  // TODO(cais): Support sampleWeight and classWeight.
  // TODO(cais): Support Dataset objects.
  const standardizeOut = await this.standardizeUserData(x, y);
  const inputs = standardizeOut[0];
  const targets = standardizeOut[1];
  const trainFunction = this.makeTrainFunction();
  const losses = trainFunction(inputs.concat(targets));
  const lossValues: number[] = [];
  for (const loss of losses) {
    const v = await loss.data();
    lossValues.push(v[0]);
  }
  tfc.dispose(losses);
  disposeNewTensors(standardizeOut[0], x);
  disposeNewTensors(standardizeOut[1], y);
  return singletonOrArray(lossValues);
}
```
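The tail of the method reads each loss tensor back with `await loss.data()`, keeps element 0, disposes the tensors, and passes the list through `singletonOrArray`. The sketch below reproduces that sequence with a stand-in `ScalarLike` type and a local `singletonOrArray` written from what its name and the `number|number[]` return type imply (a one-element list is unwrapped). Neither is imported from tfjs, and `fakeScalar` is a hypothetical helper for the demo.

```typescript
// Minimal stand-in for a scalar tf.Tensor: data() resolves to a typed array
// and dispose() releases it. Only what the sketch needs is modeled.
interface ScalarLike {
  data(): Promise<Float32Array>;
  dispose(): void;
}

// Local version of the contract the name suggests: unwrap a one-element list.
function singletonOrArray<T>(xs: T[]): T|T[] {
  return xs.length === 1 ? xs[0] : xs;
}

// Mirrors the end of trainOnBatch: await each loss value, keep its first
// (only) element, release the tensors, then unwrap a single value.
async function readLosses(losses: ScalarLike[]): Promise<number|number[]> {
  const values: number[] = [];
  for (const loss of losses) {
    const v = await loss.data();
    values.push(v[0]);
  }
  // Safe to dispose now: data() copied the values out of the tensors.
  losses.forEach((t) => t.dispose());
  return singletonOrArray(values);
}

// Hypothetical helper: a fake scalar that records its own disposal.
function fakeScalar(value: number, log: string[]): ScalarLike {
  return {
    data: async () => Float32Array.of(value),
    dispose: () => {
      log.push(`disposed ${value}`);
    },
  };
}

const disposed: string[] = [];
readLosses([fakeScalar(0.25, disposed), fakeScalar(0.5, disposed)])
    .then((r) => console.log(r, disposed.length));  // [ 0.25, 0.5 ] 2
```

Reading values through the asynchronous `data()` rather than `dataSync()` lets the method wait on the backend without blocking the calling thread, which matters on the WebGL backend in a browser.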

Callers 2

training_test.ts · File · 0.45
container_test.ts · File · 0.45

Calls 8

standardizeUserData · Method · 0.95
makeTrainFunction · Method · 0.95
disposeNewTensors · Function · 0.90
singletonOrArray · Function · 0.90
concat · Method · 0.65
data · Method · 0.65
push · Method · 0.45
dispose · Method · 0.45
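Of the callees listed above, `disposeNewTensors` is called on the standardized inputs with `x` and on the standardized targets with `y`. Judging from those call sites, its job appears to be freeing only the tensors that standardization created, while leaving the caller's own tensors untouched. The sketch below models just that idea with a stand-in `FakeTensor` class and a hypothetical `disposeNew` function; it is an assumption about intent, not the library's implementation.

```typescript
// Stand-in for tf.Tensor: a unique id plus a disposal flag.
class FakeTensor {
  private static nextId = 0;
  readonly id = FakeTensor.nextId++;
  isDisposed = false;
  dispose(): void {
    this.isDisposed = true;
  }
}

type TensorArg = FakeTensor|FakeTensor[]|{[name: string]: FakeTensor};

// Collects every tensor the caller supplied, whatever shape `arg` has.
function flatten(arg: TensorArg): FakeTensor[] {
  if (arg instanceof FakeTensor) return [arg];
  if (Array.isArray(arg)) return arg;
  const byName = arg;  // remaining case: name -> tensor map
  return Object.keys(byName).map((k) => byName[k]);
}

// Hypothetical sketch: dispose tensors in `produced` that were not part of
// what the caller passed in as `supplied`.
function disposeNew(produced: FakeTensor[], supplied: TensorArg): void {
  const callerIds = new Set(flatten(supplied).map((t) => t.id));
  for (const t of produced) {
    if (!callerIds.has(t.id) && !t.isDisposed) {
      t.dispose();
    }
  }
}

// Usage: `converted` stands for a tensor created during standardization.
const userX = new FakeTensor();
const converted = new FakeTensor();
disposeNew([userX, converted], userX);
console.log(userX.isDisposed, converted.isDisposed);  // false true
```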

Tested by

no test coverage detected