| 14 | return 0.5 * input * (1.0 + jt.tanh(math.sqrt(2.0 / math.pi) * output)) |
| 15 | |
def fixed_pos_embedding(x, seq_dim=1, seq_len=None, base=10000):
    """Build the sine/cosine tables used for rotary position embeddings.

    Only ``x.shape`` is read; the values of ``x`` are ignored.

    Args:
        x: Input tensor. Its last dimension ``dim`` determines the number of
            frequencies, ``ceil(dim / 2)`` (from ``arange(0, dim, 2)``).
        seq_dim: Axis of ``x`` holding the sequence length; consulted only
            when ``seq_len`` is None.
        seq_len: Number of positions to generate. Defaults to
            ``x.shape[seq_dim]``.
        base: Base of the geometric frequency progression; frequency ``k`` is
            ``base ** (-2k / dim)``. Defaults to 10000.

    Returns:
        ``(sin, cos)``, each of shape ``(seq_len, ceil(dim / 2))``, where entry
        ``[p, k]`` is the sine/cosine of ``p * base ** (-2k / dim)``. The
        tables are float32, or float16 when ``jt.flags.use_tensorcore`` is set.
    """
    dim = x.shape[-1]
    if seq_len is None:
        seq_len = x.shape[seq_dim]
    inv_freq = 1.0 / (base ** (jt.arange(0, dim, 2) / dim))
    positions = jt.arange(seq_len, dtype=jt.float)
    # Outer product: angle[p, k] = p * inv_freq[k].
    sinusoid_inp = jt.einsum("i , j -> i j", positions, inv_freq).float()
    # Evaluate sin/cos in float32 and only downcast the results. The angles
    # grow up to seq_len - 1 radians (inv_freq[0] == 1), and float16's
    # spacing is already 1.0 for values in [1024, 2048). Casting the angles
    # themselves to half would round them by up to half a radian and
    # corrupt the embedding for long sequences.
    sin, cos = jt.sin(sinusoid_inp), jt.cos(sinusoid_inp)
    if jt.flags.use_tensorcore:
        sin, cos = sin.half(), cos.half()
    return sin, cos
| 27 | |
| 28 | def rotate_every_two(x): |
| 29 | x1 = x[:, :, :, ::2] |