(self)
| 1184 | """ |
| 1185 | |
def setup(self):
    """Create the decoder's submodules and constants from ``self.config``.

    Always defines ``dropout_layer``, ``padding_idx``, ``embed_scale``,
    ``offset``, ``layers`` and ``layernorm_embedding``.
    ``embed_positions`` is defined only when
    ``config.use_absolute_position_embeddings`` is truthy, and ``final_ln``
    only when ``config.use_final_ln_decoder`` is truthy and
    ``config.ln_positions`` is not ``"postln"``.
    """
    cfg = self.config
    embed_dim = cfg.d_model

    self.dropout_layer = nn.Dropout(rate=cfg.dropout)
    self.padding_idx = cfg.pad_token_id

    # The scale factor is sqrt(d_model) when the config requests embedding
    # scaling, and the neutral value 1.0 otherwise.
    if cfg.scale_embedding:
        self.embed_scale = math.sqrt(embed_dim)
    else:
        self.embed_scale = 1.0

    # Bart is set up so that if padding_idx is specified then the embedding
    # ids are offset by 2 and num_embeddings is adjusted appropriately.
    # Other models don't have this hack; here the offset is fixed at 0.
    self.offset = 0
    if cfg.use_absolute_position_embeddings:
        num_positions = cfg.image_length + self.offset  # image length for BOS
        self.embed_positions = nn.Embed(
            num_positions,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(cfg.init_std),
        )

    self.layers = FlaxBartDecoderLayerCollection(cfg, self.dtype)

    # Both normalization layers share the same dtype and epsilon.
    ln_kwargs = dict(dtype=self.dtype, epsilon=1e-05)
    self.layernorm_embedding = norm(cfg.ln_type, **ln_kwargs)

    # postln is already applied in every layer, so a final layer norm is
    # only created for the other ln_positions settings.
    if cfg.use_final_ln_decoder and cfg.ln_positions != "postln":
        self.final_ln = norm(
            cfg.ln_type,
            use_scale=cfg.force_ln_scale,
            **ln_kwargs,
        )
| 1218 | |
| 1219 | def __call__( |
| 1220 | self, |
nothing calls this directly
no test coverage detected