/**
 * The permutation `dims` with axis 0 (the batch axis) prepended, i.e.
 * `[0].concat(dims)`, so it indexes every axis of the full input tensor.
 */
private readonly dimsIncludingBatch: number[];

/**
 * Creates a Permute layer.
 *
 * @param args Layer configuration. `args.dims` must be an Array containing
 *   each of the integers `1..N` exactly once, where `N` is
 *   `args.dims.length`. Axis 0 (batch) is not listed in `dims`.
 * @throws Error if `args.dims` is missing, is not an Array, or is not a
 *   permutation of consecutive integers starting from 1.
 */
constructor(args: PermuteLayerArgs) {
  super(args);
  if (args.dims == null) {
    throw new Error(
        'Required configuration field `dims` is missing during Permute ' +
        'constructor call.');
  }
  if (!Array.isArray(args.dims)) {
    throw new Error(
        'Permute constructor requires `dims` to be an Array, but received ' +
        `${args.dims} instead.`);
  }

  // Check the validity of the permutation indices: once sorted, `dims` must
  // equal the consecutive integers starting from 1 produced by `range`.
  // A numeric comparator is required: the default `Array.prototype.sort`
  // compares elements as strings, which orders e.g. [1, 10, 2, ...] and
  // would wrongly reject valid permutations with 10 or more entries.
  // `slice()` sorts a copy so the caller's array is left untouched.
  const expectedSortedIndices = range(1, args.dims.length + 1);
  const sortedDims = args.dims.slice().sort((a, b) => a - b);
  if (!util.arraysEqual(sortedDims, expectedSortedIndices)) {
    throw new Error(
        'Invalid permutation `dims`: ' + JSON.stringify(args.dims) +
        ' `dims` must contain consecutive integers starting from 1.');
  }

  this.dims = args.dims;
  this.dimsIncludingBatch = [0].concat(this.dims);
  // Expected input rank is the permuted axes plus the leading batch axis.
  this.inputSpec = [new InputSpec({ndim: this.dims.length + 1})];
}
| 599 | |
| 600 | override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] { |
| 601 | inputShape = getExactlyOneShape(inputShape); |