MCPcopy Index your code
hub / github.com/tensorflow/tfjs / standardizeUserDataXY

Method standardizeUserDataXY

tfjs-layers/src/engine/training.ts:1145–1184  ·  view source on GitHub ↗
(
      x: Tensor|Tensor[]|{[inputName: string]: Tensor},
      y: Tensor|Tensor[]|{[inputName: string]: Tensor}, checkBatchAxis = true,
      batchSize?: number)

Source from the content-addressed store, hash-verified

1143 }
1144
/**
 * Checks that the model has been compiled, then converts user-supplied
 * inputs `x` and targets `y` into tensor arrays via `standardizeInputData`,
 * matched against the model's feed input/output names and shapes.
 *
 * @param x Model inputs: a tensor, an array of tensors, or an object keyed by
 *   input name.
 * @param y Targets: a tensor, an array of tensors, or an object keyed by
 *   name.
 * @param checkBatchAxis Not read by this method. Both `standardizeInputData`
 *   calls below pass a literal `false`, so the batch axis is never checked
 *   here regardless of this argument. It is kept for signature
 *   compatibility.
 * @param batchSize Optional. Only consulted for stateful models. When it is a
 *   positive number, the sample count of the first input (its `shape[0]`)
 *   must be divisible by it.
 * @returns A `[inputs, targets]` pair of tensor arrays.
 * @throws RuntimeError If the model has no optimizer, i.e. `compile()` has
 *   not been called.
 * @throws ValueError If a stateful model receives a sample count that is not
 *   a multiple of `batchSize`. Nothing is caught here, so errors raised by
 *   the helper checks propagate to the caller unchanged.
 */
protected standardizeUserDataXY(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
    y: Tensor|Tensor[]|{[inputName: string]: Tensor}, checkBatchAxis = true,
    batchSize?: number): [Tensor[], Tensor[]] {
  // TODO(cais): Add sampleWeight, classWeight
  if (this.optimizer_ == null) {
    throw new RuntimeError(
        'You must compile a model before training/testing. Use ' +
        'LayersModel.compile(modelCompileArgs).');
  }

  // Expected shape for each target. Outputs whose loss is sparse categorical
  // crossentropy get the output shape with its last dimension replaced by 1.
  // Presumably this is because such targets hold class indices rather than
  // one-hot vectors. Every other output's shape is used unchanged.
  // Porting Note: Because of strong typing `lossFn` must be a function.
  const targetShapes: Shape[] = this.feedOutputShapes.map(
      (outShape: Shape, idx: number): Shape =>
          this.feedLossFns[idx] === losses.sparseCategoricalCrossentropy ?
          outShape.slice(0, outShape.length - 1).concat([1]) :
          outShape);

  // TODO(cais): Standardize sampleWeights & classWeights.
  const inputs = standardizeInputData(
      x, this.feedInputNames, this.feedInputShapes, false, 'input');
  const targets = standardizeInputData(
      y, this.feedOutputNames, targetShapes, false, 'target');

  checkArrayLengths(inputs, targets, null);
  // TODO(cais): Check sampleWeights as well.
  // The loss/target check uses the model's original output shapes, not the
  // adjusted targetShapes computed above.
  checkLossAndTargetCompatibility(
      targets, this.feedLossFns, this.feedOutputShapes);

  // A stateful model may only be fed a whole number of batches.
  if (this.stateful && batchSize != null && batchSize > 0) {
    const numSamples = inputs[0].shape[0];
    if (numSamples % batchSize !== 0) {
      throw new ValueError(
          `In a stateful network, you should only pass inputs with a ` +
          `number of samples that is divisible by the batch size ` +
          `${batchSize}. Found: ${numSamples} sample(s).`);
    }
  }
  return [inputs, targets];
}
1185
1186 protected async standardizeUserData(
1187 x: Tensor|Tensor[]|{[inputName: string]: Tensor},

Callers 2

evaluate (method) · 0.95
standardizeUserData (method) · 0.95

Calls 6

standardizeInputData (function) · 0.85
checkArrayLengths (function) · 0.85
concat (method) · 0.65
slice (method) · 0.65
push (method) · 0.45

Tested by

no test coverage detected