| 349 | self.apply(init_weights) |
| 350 | |
def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
    """Embed ``ids`` and run them through every block under a causal mask.

    Args:
        ids: Token-id tensor. Its ``size()`` is unpacked into exactly two
            values, so it must be 2-D; the second dimension is used as the
            sequence length ``s``.
        mask: Optional attention mask.
            * ``None``: a lower-triangular mask of ones with shape
              ``(1, s, s)`` is created on ``ids.device``.
            * 2-D: a singleton dimension is inserted at position 1, the
              result is expanded to ``s`` along that dimension, and the
              lower triangle is kept.
            * any other rank: forwarded to the blocks unchanged.
        encoder_states: Forwarded unchanged to every block
            (presumably encoder outputs for cross-attention -- confirm
            against the block implementation).
        encoder_mask: Forwarded unchanged to every block.

    Returns:
        The result of applying ``self.norm`` and then ``self.dropout`` to
        the output of the last block.
    """
    # Two-value unpack keeps the implicit "ids is 2-D" check; only the
    # sequence length is needed below.
    _, s = ids.size()

    # causal mask
    if mask is None:
        # Allocate straight on the target device rather than building the
        # tensor on the CPU and copying it over on every call.
        mask = torch.tril(torch.ones(1, s, s, device=ids.device))
    elif mask.ndim == 2:
        mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

    # layers
    x = self.token_embedding(ids)
    x = self.dropout(x)

    # When the position embedding is shared it is computed once here and
    # handed to every block; otherwise the blocks receive ``None``.
    if self.shared_pos:
        e = self.pos_embedding(x.size(1), x.size(1))
    else:
        e = None

    for block in self.blocks:
        x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)

    x = self.norm(x)
    x = self.dropout(x)
    return x
| 370 | |
| 371 | |
| 372 | class T5Model(nn.Module): |