MCPcopy
hub / github.com/tensorflow/tfjs / apply

Method apply

tfjs-layers/src/initializers.ts:332–357  ·  view source on GitHub ↗
(shape: Shape, dtype?: DataType)

Source from the content-addressed store, hash-verified

330 }
331
/**
 * Builds an initial-value tensor of the given shape by sampling from the
 * configured distribution, with the variance scaled by the shape's fans.
 *
 * The base `scale` is divided by `fanIn`, `fanOut`, or their average
 * (selected by `this.mode`); the divisor is clamped to be at least 1.
 *
 * @param shape Shape of the tensor to create.
 * @param dtype Optional dtype. On the 'normal' path it defaults to
 *     'float32' and must be 'float32' or 'int32'; on the other path it is
 *     forwarded unchanged to `randomUniform`.
 * @returns The sampled tensor.
 * @throws NotImplementedError if the 'normal' distribution is requested
 *     with a dtype other than 'float32' or 'int32'.
 */
apply(shape: Shape, dtype?: DataType): Tensor {
  const [fanIn, fanOut] = computeFans(shape);

  // Select the fan count that normalizes the variance for this mode.
  let fanDivisor: number;
  switch (this.mode) {
    case 'fanIn':
      fanDivisor = fanIn;
      break;
    case 'fanOut':
      fanDivisor = fanOut;
      break;
    default:
      // Any other mode uses the mean of the two fans.
      fanDivisor = (fanIn + fanOut) / 2;
  }
  // Clamp to 1 so small or zero fans never inflate the scale.
  const variance = this.scale / Math.max(1, fanDivisor);

  if (this.distribution !== 'normal') {
    // Uniform samples drawn from [-bound, bound] with bound = sqrt(3 * variance).
    const bound = Math.sqrt(3 * variance);
    return randomUniform(shape, -bound, bound, dtype, this.seed);
  }

  // Normal path: default the dtype, then reject unsupported ones.
  const resolvedDtype = dtype || 'float32';
  if (!(resolvedDtype === 'float32' || resolvedDtype === 'int32')) {
    throw new NotImplementedError(
        `${this.getClassName()} does not support dType ${resolvedDtype}.`);
  }
  return truncatedNormal(
      shape, 0, Math.sqrt(variance), resolvedDtype, this.seed);
}
358
359 override getConfig(): serialization.ConfigDict {
360 return {

Callers

nothing calls this directly

Calls 6

computeFansFunction · 0.85
truncatedNormalFunction · 0.85
randomUniformFunction · 0.85
maxMethod · 0.80
sqrtMethod · 0.80
getClassNameMethod · 0.65

Tested by

no test coverage detected