(prompt)
| 38 | |
| 39 | |
def recover_message_list(prompt):
    """Split a role-tagged prompt string back into a list of chat messages.

    The prompt is scanned for the role tokens ``<|system|>``, ``<|user|>``
    and ``<|assistant|>``. Each stretch of text lying between two
    consecutive role tokens becomes one message, attributed to the role
    token that opens it.

    Args:
        prompt: The flattened prompt string containing role tokens.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in prompt order.
        Text following the final role token is not emitted, because a
        message is only closed when the next role token is found; a prompt
        with fewer than two role tokens therefore yields an empty list.
    """
    token_to_role = {
        "<|system|>": "system",
        "<|user|>": "user",
        "<|assistant|>": "assistant",
    }
    # Alternation of the escaped tokens, in the same order as the mapping.
    token_regex = "|".join([re.escape(token) for token in token_to_role])

    messages = []
    opening = None  # match object of the role token that opened the segment
    for current in re.finditer(token_regex, prompt):
        if opening is not None:
            # The content starts one character past the opening token, so
            # the character right after the token is dropped (presumably
            # the newline separating token and text -- confirm with the
            # prompt builder).
            body = prompt[opening.end() + 1 : current.start()]
            messages.append(
                {
                    "role": token_to_role.get(opening.group(), "assistant"),
                    "content": body,
                }
            )
        opening = current

    return messages
| 63 | |
| 64 | |
| 65 | @torch.inference_mode() |
no outgoing calls
no test coverage detected
searching dependent graphs…