MCPcopy Index your code
hub / github.com/karpathy/makemore / MLP

Class MLP

makemore.py:350–394  ·  view source on GitHub ↗

Takes the previous block_size tokens, encodes them with a lookup table, concatenates the vectors, and predicts the next token with an MLP. Reference: Bengio et al. 2003, https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf

Source from the content-addressed store, hash-verified

348# MLP language model
349
class MLP(nn.Module):
    """
    MLP language model over a fixed-size window of tokens.

    For every position the model looks up the embedding of the token at that
    position and of the ``block_size - 1`` tokens before it, concatenates the
    vectors and predicts the next token with a one-hidden-layer MLP. Window
    slots that fall before the start of the sequence are filled with a
    special <BLANK> token whose index is ``vocab_size``.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, config):
        """
        Build the embedding table and the MLP head.

        `config` must provide the attributes `block_size`, `vocab_size`,
        `n_embd` (embedding width) and `n_embd2` (hidden layer width).
        """
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        # +1 row for the special <BLANK> token that stands in for positions
        # before the beginning of the input sequence
        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd)  # token embeddings table
        self.mlp = nn.Sequential(
            nn.Linear(self.block_size * config.n_embd, config.n_embd2),
            nn.Tanh(),
            nn.Linear(config.n_embd2, self.vocab_size),
        )

    def get_block_size(self):
        """Return the number of tokens in the model's context window."""
        return self.block_size

    def forward(self, idx, targets=None):
        """
        Compute next-token logits and, optionally, the cross-entropy loss.

        idx:     integer tensor of token indices, shape (b, t); not modified.
        targets: optional integer tensor with b * t elements; entries equal
                 to -1 are ignored by the loss.

        Returns `(logits, loss)` where `logits` has shape (b, t, vocab_size)
        and `loss` is None when `targets` is not given.
        """
        t = idx.size(1)
        last = self.block_size - 1

        # left-pad with <BLANK> so every position has a full window, then do
        # a single embedding lookup for the whole padded sequence
        blanks = idx.new_full((idx.size(0), last), self.vocab_size)
        tok_emb = self.wte(torch.cat((blanks, idx), dim=1))  # (b, t + block_size - 1, n_embd)

        # slot k of the window holds the token k steps back from each
        # position (k = 0 is the token at the position itself)
        embs = [tok_emb[:, last - k:last - k + t] for k in range(self.block_size)]

        # concat all of the embeddings together and pass through an MLP
        x = torch.cat(embs, -1)  # (b, t, n_embd * block_size)
        logits = self.mlp(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss
395
396# -----------------------------------------------------------------------------
397# Bigram language model

Callers 1

makemore.pyFile · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected