Converts all turns of a dialogue between a user and an assistant to a standardized format. Adapted from OpenAI's ChatML (https://github.com/openai/openai-python/blob/main/chatml.md) and Vicuna (https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py).
| 30 | |
| 31 | @dataclass |
| 32 | class DialogueTemplate(ModelHubMixin): |
| 33 | """Converts all turns of a dialogue between a user and assistant to a standardized format. |
| 34 | |
| 35 | Adapted from OpenAI's ChatML (https://github.com/openai/openai-python/blob/main/chatml.md) and Vicuna (https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py) |
| 36 | """ |
| 37 | |
| 38 | system: str |
| 39 | messages: List[Dict[str, str]] = None |
| 40 | system_token: str = "<|system|>" |
| 41 | user_token: str = "<|user|>" |
| 42 | assistant_token: str = "<|assistant|>" |
| 43 | end_token: str = "<|end|>" |
| 44 | |
| 45 | def get_training_prompt(self) -> str: |
| 46 | prompt = self.system_token + "\n" + self.system + self.end_token + "\n" |
| 47 | if self.messages is None: |
| 48 | raise ValueError("Dialogue template must have at least one message.") |
| 49 | for message in self.messages: |
| 50 | if message["role"] == "user": |
| 51 | prompt += self.user_token + "\n" + message["content"] + self.end_token + "\n" |
| 52 | else: |
| 53 | prompt += self.assistant_token + "\n" + message["content"] + self.end_token + "\n" |
| 54 | return prompt |
| 55 | |
| 56 | def get_inference_prompt(self) -> str: |
| 57 | prompt = self.system_token + "\n" + self.system + self.end_token + "\n" |
| 58 | if self.messages is None: |
| 59 | raise ValueError("Dialogue template must have at least one message.") |
| 60 | for message in self.messages: |
| 61 | if message["role"] == "user": |
| 62 | prompt += self.user_token + "\n" + message["content"] + self.end_token + "\n" |
| 63 | else: |
| 64 | prompt += self.assistant_token + "\n" + message["content"] + self.end_token + "\n" |
| 65 | prompt += self.assistant_token |
| 66 | return prompt |
| 67 | |
| 68 | def get_dialogue(self): |
| 69 | """Helper function to format the messages as an easy-to-read dialogue.""" |
| 70 | prompt = "" |
| 71 | if self.messages is None: |
| 72 | raise ValueError("Dialogue template must have at least one message.") |
| 73 | for message in self.messages: |
| 74 | if message["role"] == "user": |
| 75 | prompt += "\n\nHuman: " + message["content"] |
| 76 | else: |
| 77 | prompt += "\n\nAssistant: " + message["content"] |
| 78 | return prompt |
| 79 | |
| 80 | def get_special_tokens(self) -> List[str]: |
| 81 | return [self.system_token, self.user_token, self.assistant_token, self.end_token] |
| 82 | |
| 83 | def copy(self): |
| 84 | return DialogueTemplate( |
| 85 | system=self.system, |
| 86 | messages=self.messages, |
| 87 | system_token=self.system_token, |
| 88 | user_token=self.user_token, |
| 89 | assistant_token=self.assistant_token, |
no outgoing calls
no test coverage detected