(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1)
| 2 | |
| 3 | class TransformerBlock: |
def __init__(self, embed_dim, num_heads, ff_dim, prenorm=False, act=lambda x: x.relu(), dropout=0.1):
    """Create the parameter tensors for one transformer block.

    Args:
      embed_dim: embedding width; must be divisible by ``num_heads`` (asserted).
      num_heads: number of attention heads; ``head_size`` is
        ``embed_dim // num_heads``.
      ff_dim: hidden width of the two-layer feed-forward part
        (``ff1``: embed_dim -> ff_dim, ``ff2``: ff_dim -> embed_dim).
      prenorm: stored as ``self.prenorm``; presumably selects pre- vs post-norm
        in the forward pass (not visible here -- confirm).
      act: callable stored as ``self.act``; the default calls ``x.relu()``.
      dropout: stored as ``self.dropout``; presumably a dropout rate -- confirm
        at the use site.

    Every projection is a ``(weight, bias)`` tuple. Weights come from
    ``Tensor.scaled_uniform(fan_in, fan_out)`` and biases from ``Tensor.zeros``.
    The two layer-norm parameter pairs are ``(ones, zeros)`` of length
    ``embed_dim``.
    """
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

    def dense(fan_in, fan_out):
        # Weight is created before bias, as in a literal (weight, bias) tuple.
        return Tensor.scaled_uniform(fan_in, fan_out), Tensor.zeros(fan_out)

    def norm_params():
        return Tensor.ones(embed_dim), Tensor.zeros(embed_dim)

    self.num_heads = num_heads
    self.head_size = embed_dim // num_heads
    self.prenorm = prenorm
    self.act = act
    self.dropout = dropout

    # NOTE(review): keep this creation and assignment order. scaled_uniform
    # presumably draws from an RNG, so reordering would change the initial
    # weights. Parameter-discovery helpers may also walk __dict__ in
    # insertion order.
    self.query = dense(embed_dim, embed_dim)
    self.key = dense(embed_dim, embed_dim)
    self.value = dense(embed_dim, embed_dim)

    self.out = dense(embed_dim, embed_dim)

    self.ff1 = dense(embed_dim, ff_dim)
    self.ff2 = dense(ff_dim, embed_dim)

    self.ln1 = norm_params()
    self.ln2 = norm_params()
| 23 | |
| 24 | def attn(self, x): |
| 25 | # x: (bs, time, embed_dim) -> (bs, time, embed_dim) |
nothing calls this directly
no test coverage detected