MCPcopy Index your code
hub / github.com/tensorflow/tfjs / checkLossAndTargetCompatibility

Function checkLossAndTargetCompatibility

tfjs-layers/src/engine/training.ts:230–269  ·  view source on GitHub ↗

Validates the compatibility of targets and loss functions. This helps prevent users from using loss functions incorrectly. @param targets `Array` of `tf.Tensor`s of targets. @param lossFns `Array` of loss functions. @param outputShapes `Array` of shapes of model outputs.

(
    targets: Tensor[], lossFns: LossOrMetricFn[], outputShapes: Shape[])

Source from the content-addressed store, hash-verified

228 * @param outputShapes `Array` of shapes of model outputs.
229 */
function checkLossAndTargetCompatibility(
    targets: Tensor[], lossFns: LossOrMetricFn[], outputShapes: Shape[]) {
  // TODO(cais): Dedicated test coverage?
  // Losses for which the target's non-batch dimensions must match the
  // model output's non-batch dimensions (checked in the second stage below).
  const keyLosses = [
    losses.meanSquaredError, losses.binaryCrossentropy,
    losses.categoricalCrossentropy
  ];
  for (let i = 0; i < targets.length; ++i) {
    const y = targets[i];
    const loss = lossFns[i];
    const shape = outputShapes[i];
    // Nothing to validate when there is no target or no loss for this output.
    if (y == null || loss == null) {
      continue;
    }
    if (loss === losses.categoricalCrossentropy) {
      // categorical_crossentropy expects [samples, classes] binary targets,
      // so a trailing dimension of size 1 is rejected.
      if (y.shape[y.shape.length - 1] === 1) {
        throw new ValueError(
            `You are passing a target array of shape ${y.shape} while using ` +
            `a loss 'categorical_crossentropy'. 'categorical_crossentropy' ` +
            `expects targets to be binary matrices (1s and 0s) of shape ` +
            `[samples, classes].`);
        // TODO(cais): Example code in error message.
      }
    }
    // An unknown output shape cannot be compared against; skip it instead of
    // crashing on `shape.slice`.
    if (shape != null && keyLosses.indexOf(loss) !== -1) {
      // Compare everything except the leading (batch) dimension.
      const slicedYShape = y.shape.slice(1);
      const slicedShape = shape.slice(1);
      for (let j = 0; j < slicedYShape.length; ++j) {
        const targetDim = slicedYShape[j];
        const outDim = slicedShape[j];
        // A null/undefined output dimension is treated as "any size".
        if (outDim != null && targetDim !== outDim) {
          throw new ValueError(
              `A target Tensor with shape ${y.shape} was passed for an ` +
              `output of shape ${shape}, while using a loss function that ` +
              `expects targets to have the same shape as the output.`);
        }
      }
    }
  }
}
270
271/**
272 * Check inputs provided by the user.

Callers 1

standardizeUserDataXYMethod · 0.85

Calls 1

sliceMethod · 0.65

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…