| 540 | } |
| 541 | |
/**
 * Builds a (semi-)orthogonal tensor of the requested shape.
 *
 * All leading dimensions of `shape` are collapsed into rows and the last
 * dimension is kept as columns, so multi-dimensional kernels (e.g. conv2d)
 * are treated as 2D matrices. A standard-normal matrix of the "tall"
 * orientation (maxDim x minDim) is QR-factorized. Q is sign-corrected by the
 * diagonal of R, transposed back when there are fewer rows than columns,
 * reshaped to `shape`, and finally scaled by `this.gain`.
 *
 * @param shape Target shape; must have at least 2 dimensions.
 * @param dtype Optional dtype; only 'int32', 'float32' or undefined accepted.
 * @returns The initialized tensor (computed inside `tidy`).
 * @throws NotImplementedError if `shape` has fewer than 2 dimensions.
 * @throws TypeError if `dtype` is not 'int32', 'float32' or undefined.
 */
apply(shape: Shape, dtype?: DataType): Tensor {
  return tidy(() => {
    if (shape.length < 2) {
      throw new NotImplementedError('Shape must be at least 2D.');
    }
    if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
      throw new TypeError(`Unsupported data type ${dtype}.`);
    }
    // The guard above narrows `dtype`. A checked declaration (instead of an
    // `as` cast re-assigned to the parameter) lets the compiler verify it.
    const sampleDtype: 'int32' | 'float32' | undefined = dtype;

    // Flatten the input shape, keeping the last dimension at its original
    // size, so that conv kernels are handled as 2D matrices.
    const numRows = util.sizeFromShape(shape.slice(0, -1));
    const numCols = shape[shape.length - 1];
    const numElements = numRows * numCols;
    if (numElements > this.ELEMENTS_WARN_SLOW) {
      console.warn(
          `Orthogonal initializer is being called on a matrix with more ` +
          `than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
          `Slowness may result.`);
    }

    // The factorization is always performed on the tall orientation.
    const minDim = Math.min(numCols, numRows);
    const maxDim = Math.max(numCols, numRows);

    // Generate a random matrix of shape [maxDim, minDim].
    const randNormalMat =
        K.randomNormal([maxDim, minDim], 0, 1, sampleDtype, this.seed);

    // Reduced QR (fullMatrices = false): per the tfjs docs Q is
    // maxDim x minDim and R is minDim x minDim.
    const [q, r] = linalg.qr(randNormalMat, false);

    // Make Q uniform. R[i][i] sits at flat index i * (minDim + 1), so a
    // stride of minDim + 1 over the first minDim * minDim entries of the
    // flattened R yields its diagonal.
    const diag =
        r.flatten().stridedSlice([0], [minDim * minDim], [minDim + 1]);
    // Scale each column of Q by the sign of the matching diagonal entry
    // (diag broadcasts across the rows of Q).
    let qMat = mul(q, diag.sign());
    if (numRows < numCols) {
      // Q was built for the transposed problem; flip it back to
      // numRows x numCols before reshaping.
      qMat = qMat.transpose();
    }

    return mul(scalar(this.gain), qMat.reshape(shape));
  });
}
| 586 | |
| 587 | override getConfig(): serialization.ConfigDict { |
| 588 | return { |