| 1392 | |
| 1393 | |
def apply_rotary_pos_emb(t, freqs):
    """Apply rotary position embeddings to the leading channels of ``t``.

    Two code paths exist:

    * If ``apply_rotary_emb_func`` is available and ``t`` lives on CUDA, the
      fused kernel is used.
    * Otherwise a pure-PyTorch rotate-half formulation is used.

    Args:
        t: Input tensor. The reference path rotates only the first
            ``rot_dim = freqs[0].shape[-1]`` channels of the last dimension
            and passes the remaining channels through unchanged.
        freqs: A ``(cos, sin)`` pair of tensors. NOTE(review): the fused path
            squeezes dims 0 and 1 and then slices ``[:, :half]``. That suggests
            a layout like ``(1, seq_len, 1, rot_dim)``; confirm against the
            caller that builds ``freqs``.

    Returns:
        A tensor with the same dtype as ``t``. On both paths the input is
        upcast to float32 before the rotation, and the result is cast back.
    """
    cos, sin = freqs

    if apply_rotary_emb_func is not None and t.is_cuda:
        # Fused kernel path. Only the first half of the last dimension of
        # cos/sin is forwarded. Presumably the kernel expects one frequency
        # per rotated pair, rather than the duplicated layout consumed by the
        # rotate-half fallback below; verify against the kernel's contract.
        cos_half = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
        sin_half = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
        return apply_rotary_emb_func(t.float(), cos_half, sin_half).type_as(t)

    # Reference path: rotate the first rot_dim channels and keep the rest.
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = t_rot.float()
    rotated = (t_rot * cos) + (_rotate_half(t_rot) * sin)
    # Upcast the pass-through slice as well, so both inputs to torch.cat share
    # a dtype. The concatenated result is then cast back to t's dtype.
    return torch.cat((rotated, t_pass.float()), dim=-1).type_as(t)
| 1410 | |
| 1411 | |
| 1412 | class RMSNorm(torch.nn.Module): |