MCPcopy
hub / github.com/borisdayma/dalle-mini / setup

Method setup

src/dalle_mini/model/modeling.py:1186–1217  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

1184 """
1185
def setup(self):
    """Build the decoder's submodules and derived constants from ``self.config``.

    Always creates ``dropout_layer``, ``padding_idx``, ``embed_scale``,
    ``offset``, ``layers`` and ``layernorm_embedding``. ``embed_positions``
    exists only when ``config.use_absolute_position_embeddings`` is truthy,
    and ``final_ln`` only when ``config.use_final_ln_decoder`` is truthy and
    ``config.ln_positions`` is not ``"postln"``.
    """
    cfg = self.config
    hidden_size = cfg.d_model

    self.dropout_layer = nn.Dropout(rate=cfg.dropout)
    self.padding_idx = cfg.pad_token_id

    # Token embeddings are optionally scaled by sqrt(d_model).
    if cfg.scale_embedding:
        scale = math.sqrt(hidden_size)
    else:
        scale = 1.0
    self.embed_scale = scale

    # Upstream Bart shifts the embedding ids by 2 when padding_idx is specified
    # and enlarges num_embeddings to match. That hack is not used here, so the
    # offset stays at 0.
    self.offset = 0
    if cfg.use_absolute_position_embeddings:
        position_init = jax.nn.initializers.normal(cfg.init_std)
        self.embed_positions = nn.Embed(
            cfg.image_length + self.offset,  # image length for BOS
            hidden_size,
            embedding_init=position_init,
        )

    self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)
    self.layernorm_embedding = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)

    # With "postln" every layer already ends in a layer norm, so a final one
    # would be redundant.
    wants_final_ln = cfg.use_final_ln_decoder and cfg.ln_positions != "postln"
    if wants_final_ln:
        self.final_ln = norm(
            cfg.ln_type,
            dtype=self.dtype,
            epsilon=1e-05,
            use_scale=cfg.force_ln_scale,
        )
1218
1219 def __call__(
1220 self,

Callers

nothing calls this directly

Calls 2

normFunction · 0.70

Tested by

no test coverage detected