MCP · copy
hub / github.com/tensorflow/tfjs / constructor

Function constructor

tfjs-layers/src/layers/convolutional.ts:417–470  ·  view source on GitHub ↗
(rank: number, args: BaseConvLayerArgs)

Source from the content-addressed store, hash-verified

/**
 * Initializer identifier used for the bias vector when the caller's
 * `args.biasInitializer` is omitted (or otherwise falsy).
 */
readonly DEFAULT_BIAS_INITIALIZER: InitializerIdentifier = 'zeros';

/**
 * Shared constructor for the 1D/2D/3D convolution layers.
 *
 * Runs `BaseConv.verifyArgs(args)`, then normalizes and stores the
 * convolution hyper-parameters on the instance. Omitted settings fall back
 * to: `strides` 1, `padding` 'valid', `dataFormat` 'channelsLast',
 * `useBias` true, `dilationRate` 1, and `DEFAULT_BIAS_INITIALIZER` for the
 * bias initializer.
 *
 * @param rank Number of spatial dimensions; only 1, 2 and 3 are accepted.
 * @param args Configuration shared by the convolution layers.
 * @throws NotImplementedError when `rank` is not 1, 2 or 3.
 * @throws ValueError when the normalized `dilationRate` is an array whose
 *     length does not match `rank`.
 * Errors thrown by the helpers called here (`normalizeArray`,
 * `checkPaddingMode`, `checkDataFormat`, `getActivation`, `getInitializer`,
 * `getConstraint`, `getRegularizer`) are not caught.
 */
constructor(rank: number, args: BaseConvLayerArgs) {
  super(args as LayerArgs);
  BaseConv.verifyArgs(args);

  // Rank: must pass the positive-integer check and be one of 1, 2, 3.
  this.rank = rank;
  generic_utils.assertPositiveInteger(this.rank, 'rank');
  const rankIsSupported =
      this.rank === 1 || this.rank === 2 || this.rank === 3;
  if (!rankIsSupported) {
    throw new NotImplementedError(
        `Convolution layer for rank other than 1, 2, or 3 (${this.rank}) ` +
        `is not implemented yet.`);
  }

  // Keep the order below: if more than one of these helpers throws, the
  // earliest call determines which error the caller sees.
  this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
  const strideSpec = args.strides != null ? args.strides : 1;
  this.strides = normalizeArray(strideSpec, rank, 'strides');

  this.padding = args.padding != null ? args.padding : 'valid';
  checkPaddingMode(this.padding);
  this.dataFormat =
      args.dataFormat != null ? args.dataFormat : 'channelsLast';
  checkDataFormat(this.dataFormat);
  this.activation = getActivation(args.activation);

  // Bias / regularization settings. `||` (not a null check) is used for the
  // initializer, so any falsy value selects the default.
  this.useBias = args.useBias != null ? args.useBias : true;
  const biasInitializerSpec =
      args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER;
  this.biasInitializer = getInitializer(biasInitializerSpec);
  this.biasConstraint = getConstraint(args.biasConstraint);
  this.biasRegularizer = getRegularizer(args.biasRegularizer);
  this.activityRegularizer = getRegularizer(args.activityRegularizer);

  // Dilation: normalized like strides, then checked against the rank.
  // NOTE(review): if normalizeArray always returns `rank` entries, the
  // per-rank checks below are purely defensive — confirm against its
  // definition.
  this.dilationRate = normalizeArray(
      args.dilationRate != null ? args.dilationRate : 1, rank,
      'dilationRate');
  switch (this.rank) {
    case 1:
      // Unlike ranks 2 and 3, a scalar dilationRate is left as-is here.
      if (Array.isArray(this.dilationRate) &&
          this.dilationRate.length !== 1) {
        throw new ValueError(
            `dilationRate must be a number or an array of a single number ` +
            `for 1D convolution, but received ` +
            `${JSON.stringify(this.dilationRate)}`);
      }
      break;
    case 2:
      if (typeof this.dilationRate === 'number') {
        this.dilationRate = [this.dilationRate, this.dilationRate];
      } else if (this.dilationRate.length !== 2) {
        throw new ValueError(
            `dilationRate must be a number or array of two numbers for 2D ` +
            `convolution, but received ${JSON.stringify(this.dilationRate)}`);
      }
      break;
    case 3:
      if (typeof this.dilationRate === 'number') {
        this.dilationRate =
            [this.dilationRate, this.dilationRate, this.dilationRate];
      } else if (this.dilationRate.length !== 3) {
        throw new ValueError(
            `dilationRate must be a number or array of three numbers for 3D ` +
            `convolution, but received ${JSON.stringify(this.dilationRate)}`);
      }
      break;
    default:
      // Not reached: unsupported ranks were rejected above.
      break;
  }
}
471
472 protected static verifyArgs(args: BaseConvLayerArgs) {
473 // Check config.kernelSize type and shape.

Callers

nothing calls this directly

Calls 8

normalizeArray (function) · 0.90
checkPaddingMode (function) · 0.90
checkDataFormat (function) · 0.90
getActivation (function) · 0.90
getInitializer (function) · 0.90
getConstraint (function) · 0.90
getRegularizer (function) · 0.90
verifyArgs (method) · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…