apply(shape: Shape, dtype?: DataType)
| 330 | } |
| 331 | |
| 332 | apply(shape: Shape, dtype?: DataType): Tensor { |
| 333 | const fans = computeFans(shape); |
| 334 | const fanIn = fans[0]; |
| 335 | const fanOut = fans[1]; |
| 336 | let scale = this.scale; |
| 337 | if (this.mode === 'fanIn') { |
| 338 | scale /= Math.max(1, fanIn); |
| 339 | } else if (this.mode === 'fanOut') { |
| 340 | scale /= Math.max(1, fanOut); |
| 341 | } else { |
| 342 | scale /= Math.max(1, (fanIn + fanOut) / 2); |
| 343 | } |
| 344 | |
| 345 | if (this.distribution === 'normal') { |
| 346 | const stddev = Math.sqrt(scale); |
| 347 | dtype = dtype || 'float32'; |
| 348 | if (dtype !== 'float32' && dtype !== 'int32') { |
| 349 | throw new NotImplementedError( |
| 350 | `${this.getClassName()} does not support dType ${dtype}.`); |
| 351 | } |
| 352 | return truncatedNormal(shape, 0, stddev, dtype, this.seed); |
| 353 | } else { |
| 354 | const limit = Math.sqrt(3 * scale); |
| 355 | return randomUniform(shape, -limit, limit, dtype, this.seed); |
| 356 | } |
| 357 | } |
| 358 | |
| 359 | override getConfig(): serialization.ConfigDict { |
| 360 | return { |
Nothing calls this directly.
No test coverage detected.