(
x: Tensor|Tensor[]|{[inputName: string]: Tensor},
y: Tensor|Tensor[]|{[inputName: string]: Tensor}, checkBatchAxis = true,
batchSize?: number)
| 1143 | } |
| 1144 | |
/**
 * Validates the `x` / `y` arguments of a training or evaluation call and
 * normalizes each of them into an array of Tensors ordered like the model's
 * feed inputs / feed outputs.
 *
 * For any output whose loss function is `losses.sparseCategoricalCrossentropy`,
 * the expected target shape is the output shape with its last dimension
 * replaced by `1`. All other outputs expect their full output shape.
 *
 * @param x Input data: a single Tensor, an array of Tensors, or a map from
 *   input name to Tensor.
 * @param y Target data, accepted in the same three forms as `x`.
 * @param checkBatchAxis Not consulted by this method. Batch-axis checking is
 *   always disabled when standardizing `x` and `y` here.
 * @param batchSize Optional batch size. If the model is stateful and this is
 *   a positive number, the first input's leading dimension must be divisible
 *   by it.
 * @returns A tuple `[inputs, targets]` of standardized Tensor arrays.
 * @throws RuntimeError if the model has not been compiled (`optimizer_` is
 *   null/undefined).
 * @throws ValueError if a stateful model receives a number of samples that is
 *   not divisible by `batchSize`.
 */
protected standardizeUserDataXY(
    x: Tensor|Tensor[]|{[inputName: string]: Tensor},
    y: Tensor|Tensor[]|{[inputName: string]: Tensor}, checkBatchAxis = true,
    batchSize?: number): [Tensor[], Tensor[]] {
  // TODO(cais): Add sampleWeight, classWeight
  if (this.optimizer_ == null) {
    throw new RuntimeError(
        'You must compile a model before training/testing. Use ' +
        'LayersModel.compile(modelCompileArgs).');
  }

  // Derive the target shape expected for each feed output. Sparse categorical
  // labels carry a single class index per sample, so the trailing dimension
  // collapses to 1. Other outputs reuse their original shape object.
  const targetShapes: Shape[] = this.feedOutputShapes.map((shape, idx) => {
    // Porting Note: Because of strong typing `lossFn` must be a function.
    const isSparse =
        this.feedLossFns[idx] === losses.sparseCategoricalCrossentropy;
    return isSparse ? [...shape.slice(0, -1), 1] : shape;
  });

  // Batch-axis checking is intentionally off (`false`) for both x and y.
  const inputs = standardizeInputData(
      x, this.feedInputNames, this.feedInputShapes, false, 'input');
  const targets = standardizeInputData(
      y, this.feedOutputNames, targetShapes, false, 'target');

  // TODO(cais): Standardize sampleWeights & classWeights.
  checkArrayLengths(inputs, targets, null);
  // TODO(cais): Check sampleWeights as well.
  checkLossAndTargetCompatibility(
      targets, this.feedLossFns, this.feedOutputShapes);

  // Stateful networks carry state across batches, so the sample count must
  // split evenly into batches of `batchSize`.
  if (this.stateful && batchSize != null && batchSize > 0) {
    const numSamples = inputs[0].shape[0];
    if (numSamples % batchSize !== 0) {
      throw new ValueError(
          `In a stateful network, you should only pass inputs with a ` +
          `number of samples that is divisible by the batch size ` +
          `${batchSize}. Found: ${numSamples} sample(s).`);
    }
  }

  return [inputs, targets];
}
| 1185 | |
| 1186 | protected async standardizeUserData( |
| 1187 | x: Tensor|Tensor[]|{[inputName: string]: Tensor}, |
no test coverage detected