(tensor: Tensor)
| 4995 | |
@staticmethod
def rotate_half(tensor: Tensor) -> Tensor:
    """Rotate the two halves of the last dimension: ``[x1, x2] -> [-x2, x1]``.

    The last dimension is split into a first half ``x1`` and a second half
    ``x2``. The result is ``concat([-x2, x1])`` along that same dimension.

    Args:
        tensor: 4-D input tensor, laid out as
            [bs, num_attention_kv_heads, seqlen, attention_head_size].
            NOTE(review): the last dimension is presumably even; an odd
            size is not checked here — confirm with callers.

    Returns:
        A tensor with the same shape and dtype as ``tensor``.
    """
    ndim = tensor.ndim()
    assert ndim == 4, (
        f"rotate_half expects a 4-D tensor, got ndim={ndim}")
    last_axis = ndim - 1

    # Half of the last dimension, built once and reused both as the size of
    # each half along the last axis and as the start offset of the second
    # half, so only one shape/div pair is added to the network.
    half = shape(tensor, last_axis) / 2
    half_shape = concat([shape(tensor, i)
                         for i in range(last_axis)] + [half])

    x1 = slice(tensor, [0] * ndim, half_shape, [1] * ndim)
    x2 = slice(tensor, concat([0] * last_axis + [half]), half_shape,
               [1] * ndim)

    # Negate x2 by subtracting it from a 1-element zero constant of the
    # input's dtype (broadcast against x2).
    zero = constant(
        np.ascontiguousarray(
            np.zeros([1], dtype=trt_dtype_to_np(tensor.dtype))))
    return concat([zero - x2, x1], last_axis)
| 5015 | |
| 5016 | @staticmethod |
| 5017 | def apply_rotary_pos_emb( |
no test coverage detected