(rank: number, args: BaseConvLayerArgs)
/**
 * Identifier of the initializer used for the bias vector whenever
 * `args.biasInitializer` is falsy.
 */
readonly DEFAULT_BIAS_INITIALIZER: InitializerIdentifier = 'zeros';

/**
 * Creates a convolution layer with `rank` spatial dimensions.
 *
 * Options that are null/undefined in `args` fall back to: strides 1,
 * padding 'valid', dataFormat 'channelsLast', useBias true, dilationRate 1.
 * A falsy `args.biasInitializer` falls back to DEFAULT_BIAS_INITIALIZER.
 *
 * @param rank Number of spatial dimensions; only 1, 2 and 3 are accepted.
 * @param args Layer configuration; passed to the base-class constructor and
 *     then checked by `BaseConv.verifyArgs`.
 * @throws NotImplementedError When `rank` is not 1, 2 or 3.
 * @throws ValueError When the stored `dilationRate` is an array whose length
 *     differs from `rank`. Errors raised by the helper validators
 *     (assertPositiveInteger, normalizeArray, checkPaddingMode,
 *     checkDataFormat, ...) propagate unchanged.
 */
constructor(rank: number, args: BaseConvLayerArgs) {
  super(args as LayerArgs);
  BaseConv.verifyArgs(args);

  this.rank = rank;
  generic_utils.assertPositiveInteger(this.rank, 'rank');
  const isSupportedRank =
      this.rank === 1 || this.rank === 2 || this.rank === 3;
  if (!isSupportedRank) {
    throw new NotImplementedError(
        `Convolution layer for rank other than 1, 2, or 3 (${this.rank}) is ` +
        `not implemented yet.`);
  }

  // Per-axis geometry; `rank` is passed to normalizeArray as the expected
  // number of entries.
  this.kernelSize = normalizeArray(args.kernelSize, rank, 'kernelSize');
  const stridesArg = args.strides == null ? 1 : args.strides;
  this.strides = normalizeArray(stridesArg, rank, 'strides');

  this.padding = args.padding == null ? 'valid' : args.padding;
  checkPaddingMode(this.padding);

  this.dataFormat =
      args.dataFormat == null ? 'channelsLast' : args.dataFormat;
  checkDataFormat(this.dataFormat);

  this.activation = getActivation(args.activation);

  // Bias configuration.
  this.useBias = args.useBias == null ? true : args.useBias;
  const biasInitializerId =
      args.biasInitializer || this.DEFAULT_BIAS_INITIALIZER;
  this.biasInitializer = getInitializer(biasInitializerId);
  this.biasConstraint = getConstraint(args.biasConstraint);
  this.biasRegularizer = getRegularizer(args.biasRegularizer);
  this.activityRegularizer = getRegularizer(args.activityRegularizer);

  const dilationArg = args.dilationRate == null ? 1 : args.dilationRate;
  this.dilationRate = normalizeArray(dilationArg, rank, 'dilationRate');

  // NOTE(review): if normalizeArray always yields an array of exactly `rank`
  // entries, the checks below are redundant safety nets - confirm against
  // its implementation. A scalar dilation is expanded to a tuple for ranks
  // 2 and 3 only; for rank 1 it is left as a number.
  switch (this.rank) {
    case 1:
      if (Array.isArray(this.dilationRate) &&
          this.dilationRate.length !== 1) {
        throw new ValueError(
            `dilationRate must be a number or an array of a single number ` +
            `for 1D convolution, but received ` +
            `${JSON.stringify(this.dilationRate)}`);
      }
      break;
    case 2:
      if (typeof this.dilationRate === 'number') {
        this.dilationRate = [this.dilationRate, this.dilationRate];
      } else if (this.dilationRate.length !== 2) {
        throw new ValueError(
            `dilationRate must be a number or array of two numbers for 2D ` +
            `convolution, but received ${JSON.stringify(this.dilationRate)}`);
      }
      break;
    case 3:
      if (typeof this.dilationRate === 'number') {
        this.dilationRate =
            [this.dilationRate, this.dilationRate, this.dilationRate];
      } else if (this.dilationRate.length !== 3) {
        throw new ValueError(
            `dilationRate must be a number or array of three numbers for 3D ` +
            `convolution, but received ${JSON.stringify(this.dilationRate)}`);
      }
      break;
  }
}
| 471 | |
| 472 | protected static verifyArgs(args: BaseConvLayerArgs) { |
| 473 | // Check config.kernelSize type and shape. |
nothing calls this directly
no test coverage detected
searching dependent graphs…