(x, c_fc, c_proj)
| 22 | |
| 23 | |
def ffn(x, c_fc, c_proj):  # [n_seq, n_embd] -> [n_seq, n_embd]
    """Apply the feed-forward sub-layer: expand, GELU, then contract.

    ``c_fc`` and ``c_proj`` are mappings that are unpacked as keyword
    arguments into ``linear`` for the up-projection and the
    down-projection respectively.
    """
    # Up-projection to the wider hidden size, then the non-linearity.
    hidden = linear(x, **c_fc)  # [n_seq, n_embd] -> [n_seq, 4*n_embd]
    activated = gelu(hidden)

    # Down-projection back to the embedding width.
    return linear(activated, **c_proj)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]
| 32 | |
| 33 | |
| 34 | def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v] |
no test coverage detected