(dim, position)
| 14 | |
| 15 | |
def sinusoidal_embedding_1d(dim, position):
    """Build a sinusoidal embedding table for a 1-D tensor of positions.

    Each position ``p`` is multiplied by ``half = dim // 2`` frequencies
    ``10000 ** (-i / half)`` for ``i = 0 .. half - 1``. The cosines of the
    resulting angles fill the first ``half`` output columns and the sines
    fill the last ``half`` columns.

    Args:
        dim: Width of the embedding; must be even.
        position: 1-D tensor of positions (``torch.outer`` only accepts
            1-D operands). It is cast to ``torch.float64`` before use.

    Returns:
        A ``torch.float64`` tensor of shape ``(len(position), dim)`` on the
        same device as ``position``.

    Raises:
        AssertionError: If ``dim`` is odd.
    """
    # preprocess
    assert dim % 2 == 0, f"dim must be even, got {dim}"
    half = dim // 2
    # Angles are computed in double precision regardless of the input dtype.
    position = position.to(torch.float64)

    # calculation
    # Create the exponent range directly with the target dtype/device instead
    # of allocating a default int64 CPU tensor and converting it afterwards.
    exponents = torch.arange(
        half, dtype=position.dtype, device=position.device).div(half)
    inv_freq = torch.pow(10000, -exponents)
    # Outer product: one row per position, one column per frequency.
    sinusoid = torch.outer(position, inv_freq)
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
| 27 | |
| 28 | |
| 29 | @amp.autocast(enabled=False) |