MCPcopy
hub / github.com/tensorflow/tfjs / apply

Method apply

tfjs-layers/src/initializers.ts:542–585  ·  view source on GitHub ↗
(shape: Shape, dtype?: DataType)

Source from the content-addressed store, hash-verified

540 }
541
apply(shape: Shape, dtype?: DataType): Tensor {
  // Everything below runs inside tidy() so intermediate tensors (the random
  // sample, the QR factors, the diagonal, ...) are released automatically.
  return tidy(() => {
    if (shape.length < 2) {
      throw new NotImplementedError('Shape must be at least 2D.');
    }
    if (!(dtype === 'int32' || dtype === 'float32' || dtype === undefined)) {
      throw new TypeError(`Unsupported data type ${dtype}.`);
    }
    // The guard above has narrowed `dtype` to the dtypes the sampler accepts.
    const sampleDtype: 'int32' | 'float32' | undefined = dtype;

    // View the requested shape as a 2D matrix: all leading dimensions are
    // collapsed into rows and the last dimension is kept as columns, so
    // higher-rank kernels (e.g. conv2d) are handled too.
    const rows = util.sizeFromShape(shape.slice(0, -1));
    const cols = shape[shape.length - 1];
    const elementCount = rows * cols;
    if (elementCount > this.ELEMENTS_WARN_SLOW) {
      console.warn(
          `Orthogonal initializer is being called on a matrix with more than ` +
          `${this.ELEMENTS_WARN_SLOW} (${elementCount}) elements: ` +
          `Slowness may result.`);
    }

    const longSide = Math.max(cols, rows);
    const shortSide = Math.min(cols, rows);

    // Sample a tall (longSide x shortSide) standard-normal matrix.
    const gaussian =
        K.randomNormal([longSide, shortSide], 0, 1, sampleDtype, this.seed);

    // Reduced QR factorization (fullMatrices = false): Q keeps the sample's
    // shape and R is shortSide x shortSide.
    const [q, r] = linalg.qr(gaussian, false);

    // Pick diag(R) out of the row-major flattening of R: the diagonal entries
    // sit every (shortSide + 1) positions.
    const diagonal = r.flatten().stridedSlice(
        [0], [shortSide * shortSide], [shortSide + 1]);

    // Multiplying each column of Q by the sign of the matching diagonal entry
    // of R makes the factorization unique, giving a uniform orthogonal sample.
    const signCorrected = mul(q, diagonal.sign());

    // The sample was drawn tall; flip it back when the target is wide.
    const oriented =
        rows < cols ? signCorrected.transpose() : signCorrected;

    return mul(scalar(this.gain), oriented.reshape(shape));
  });
}
586
587 override getConfig(): serialization.ConfigDict {
588 return {

Callers

nothing calls this directly

Calls (10)

tidy — function · 0.90
scalar — function · 0.90
max — method · 0.80
min — method · 0.80
stridedSlice — method · 0.80
flatten — method · 0.80
sign — method · 0.80
transpose — method · 0.80
reshape — method · 0.80
slice — method · 0.65

Tested by

no test coverage detected