Takes the previous block_size tokens, encodes them with a lookup table, concatenates the resulting vectors, and predicts the next token with an MLP. Reference: Bengio et al. 2003, https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
| 348 | # MLP language model |
| 349 | |
class MLP(nn.Module):
    """
    MLP language model over a fixed-size window of tokens.

    Takes the previous block_size tokens, encodes them with a lookup table,
    concatenates the vectors and predicts the next token with an MLP.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, config):
        """
        Build the embedding table and the two-layer MLP.

        Reads block_size, vocab_size, n_embd and n_embd2 from config.
        """
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        # +1 row for a special <BLANK> token (index == vocab_size) that gets
        # inserted when encoding a token before the beginning of the input sequence
        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd)  # token embeddings table
        self.mlp = nn.Sequential(
            nn.Linear(self.block_size * config.n_embd, config.n_embd2),
            nn.Tanh(),
            nn.Linear(config.n_embd2, self.vocab_size),
        )

    def get_block_size(self):
        """Return the number of tokens in the context window."""
        return self.block_size

    def forward(self, idx, targets=None):
        """
        Compute next-token logits and, optionally, the loss.

        idx is a 2-D tensor of token indices, (b, t). targets, if given, must
        flatten to one label per (b, t) position; entries equal to -1 are
        ignored by the loss.

        Returns (logits, loss): logits has shape (b, t, vocab_size) and loss is
        None when targets is None.
        """
        # gather, for every position, the embeddings of the token at that
        # position and of the block_size - 1 tokens before it
        embs = []
        for _ in range(self.block_size):
            embs.append(self.wte(idx))  # token embeddings of shape (b, t, n_embd)
            # shift the sequence one step to the right; torch.roll returns a new
            # tensor, so the in-place write below never touches the caller's idx
            idx = torch.roll(idx, 1, 1)
            idx[:, 0] = self.vocab_size  # the wrapped-around slot becomes <BLANK>

        # concat all of the embeddings together and pass through an MLP
        x = torch.cat(embs, -1)  # (b, t, n_embd * block_size)
        logits = self.mlp(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )

        return logits, loss
| 395 | |
| 396 | # ----------------------------------------------------------------------------- |
| 397 | # Bigram language model |