* Validation on the compatibility of targets and loss functions. * * This helps prevent users from using loss functions incorrectly. * * @param targets `Array` of `tf.Tensor`s of targets. * @param lossFns `Array` of loss functions. * @param outputShapes `Array` of shapes of model outputs.
(
targets: Tensor[], lossFns: LossOrMetricFn[], outputShapes: Shape[])
| 228 | * @param outputShapes `Array` of shapes of model outputs. |
| 229 | */ |
function checkLossAndTargetCompatibility(
    targets: Tensor[], lossFns: LossOrMetricFn[], outputShapes: Shape[]) {
  // Checks, for every index `i`, that `targets[i]` is compatible with
  // `lossFns[i]` and `outputShapes[i]`. Returns nothing; throws a
  // `ValueError` on the first incompatibility found. Entries whose loss
  // function is null/undefined are skipped.
  // TODO(cais): Dedicated test coverage?

  // Losses that require the target to have the same shape as the output.
  const keyLosses = [
    losses.meanSquaredError, losses.binaryCrossentropy,
    losses.categoricalCrossentropy
  ];
  for (let i = 0; i < targets.length; ++i) {
    const y = targets[i];
    const loss = lossFns[i];
    const shape = outputShapes[i];
    if (loss == null) {
      continue;
    }
    if (loss === losses.categoricalCrossentropy) {
      // A last dimension of size 1 cannot hold a [samples, classes] binary
      // matrix, so reject it up front with a dedicated message.
      if (y.shape[y.shape.length - 1] === 1) {
        // TODO(cais): Example code in error message.
        throw new ValueError(
            `You are passing a target array of shape ${y.shape} while using ` +
            `a loss 'categorical_crossentropy'. 'categorical_crossentropy' ` +
            `expects targets to be binary matrices (1s and 0s) of shape ` +
            `[samples, classes].`);
      }
    }
    if (keyLosses.indexOf(loss) !== -1) {
      // The first axis is excluded from the comparison (presumably the
      // batch/samples axis -- confirm against callers).
      const slicedYShape = y.shape.slice(1);
      const slicedShape = shape.slice(1);
      for (let j = 0; j < slicedYShape.length; ++j) {
        const targetDim = slicedYShape[j];
        const outDim = slicedShape[j];
        // A null/undefined output dimension is not checked; this also covers
        // target axes that have no counterpart in the output shape.
        if (outDim != null && targetDim !== outDim) {
          throw new ValueError(
              `A target Tensor with shape ${y.shape} was passed for an ` +
              `output of shape ${shape}, while using a loss function that ` +
              `expects targets to have the same shape as the output.`);
        }
      }
    }
  }
}
| 270 | |
| 271 | /** |
| 272 | * Check inputs provided by the user. |
no test coverage detected
searching dependent graphs…