MCPcopy
hub / github.com/tensorflow/tfjs / constructor

Method constructor

tfjs-layers/src/layers/embeddings.ts:92–126  ·  view source on GitHub ↗
(args: EmbeddingLayerArgs)

Source from the content-addressed store, hash-verified

90 private readonly embeddingsConstraint?: Constraint;
91
/**
 * Builds an Embedding layer from its configuration.
 *
 * If the caller supplied neither `batchInputShape` nor `inputShape`, the
 * batch input shape is derived here from `batchSize` and `inputLength`.
 * This happens after `super(args)` has already run.
 *
 * @param args Layer configuration. `inputDim` and `outputDim` are checked
 *   with `generic_utils.assertPositiveInteger` right after being stored.
 */
constructor(args: EmbeddingLayerArgs) {
  super(args);

  const hasExplicitShape =
      args.batchInputShape != null || args.inputShape != null;
  if (!hasExplicitShape) {
    // Porting Note: This logic is copied from Layer's constructor, since we
    // can't do exactly what the Python constructor does for Embedding().
    // Specifically, the super constructor can not be called after the
    // mutation of the `config` argument.
    const batchSize: number = args.batchSize != null ? args.batchSize : null;
    // Reproduce what the super constructor would have produced for an
    // `inputShape` of (None, ) when `inputLength` is absent, or of
    // (inputLength, ) when it is present.
    this.batchInputShape = args.inputLength == null ?
        [batchSize, null] :
        [batchSize].concat(generic_utils.toList(args.inputLength));
  }

  // Vocabulary size and embedding width; each is validated immediately after
  // it is stored, so an invalid `inputDim` is reported before `outputDim`.
  this.inputDim = args.inputDim;
  generic_utils.assertPositiveInteger(this.inputDim, 'inputDim');
  this.outputDim = args.outputDim;
  generic_utils.assertPositiveInteger(this.outputDim, 'outputDim');

  // Fall back to the class default when no initializer was given (falsy).
  const initializerSpec =
      args.embeddingsInitializer || this.DEFAULT_EMBEDDINGS_INITIALIZER;
  this.embeddingsInitializer = getInitializer(initializerSpec);
  this.embeddingsRegularizer = getRegularizer(args.embeddingsRegularizer);
  this.activityRegularizer = getRegularizer(args.activityRegularizer);
  this.embeddingsConstraint = getConstraint(args.embeddingsConstraint);

  // Masking support follows the `maskZero` option directly.
  this.maskZero = args.maskZero;
  this.supportsMasking = args.maskZero;
  this.inputLength = args.inputLength;
}
127
128 public override build(inputShape: Shape|Shape[]): void {
129 this.embeddings = this.addWeight(

Callers

nothing calls this directly

Calls 4

getInitializerFunction · 0.90
getRegularizerFunction · 0.90
getConstraintFunction · 0.90
concatMethod · 0.65

Tested by

no test coverage detected