| 30 | skip_next: bool = False |
| 31 | |
| 32 | def get_prompt(self): |
| 33 | messages = self.messages |
| 34 | if len(messages) > 0 and type(messages[0][1]) is tuple: |
| 35 | messages = self.messages.copy() |
| 36 | init_role, init_msg = messages[0].copy() |
| 37 | init_msg = init_msg[0].replace("<image>", "").strip() |
| 38 | if 'mmtag' in self.version: |
| 39 | messages[0] = (init_role, init_msg) |
| 40 | messages.insert(0, (self.roles[0], "<Image><image></Image>")) |
| 41 | messages.insert(1, (self.roles[1], "Received.")) |
| 42 | else: |
| 43 | messages[0] = (init_role, "<image>\n" + init_msg) |
| 44 | |
| 45 | if self.sep_style == SeparatorStyle.SINGLE: |
| 46 | ret = self.system + self.sep |
| 47 | for role, message in messages: |
| 48 | if message: |
| 49 | if type(message) is tuple: |
| 50 | message, _, _ = message |
| 51 | ret += role + ": " + message + self.sep |
| 52 | else: |
| 53 | ret += role + ":" |
| 54 | elif self.sep_style == SeparatorStyle.TWO: |
| 55 | seps = [self.sep, self.sep2] |
| 56 | ret = self.system + seps[0] |
| 57 | for i, (role, message) in enumerate(messages): |
| 58 | if message: |
| 59 | if type(message) is tuple: |
| 60 | message, _, _ = message |
| 61 | ret += role + ": " + message + seps[i % 2] |
| 62 | else: |
| 63 | ret += role + ":" |
| 64 | elif self.sep_style == SeparatorStyle.MPT: |
| 65 | ret = self.system + self.sep |
| 66 | for role, message in messages: |
| 67 | if message: |
| 68 | if type(message) is tuple: |
| 69 | message, _, _ = message |
| 70 | ret += role + message + self.sep |
| 71 | else: |
| 72 | ret += role |
| 73 | elif self.sep_style == SeparatorStyle.LLAMA_2: |
| 74 | wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg |
| 75 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" |
| 76 | ret = "" |
| 77 | |
| 78 | for i, (role, message) in enumerate(messages): |
| 79 | if i == 0: |
| 80 | assert message, "first message should not be none" |
| 81 | assert role == self.roles[0], "first message should come from user" |
| 82 | if message: |
| 83 | if type(message) is tuple: |
| 84 | message, _, _ = message |
| 85 | if i == 0: message = wrap_sys(self.system) + message |
| 86 | if i % 2 == 0: |
| 87 | message = wrap_inst(message) |
| 88 | ret += self.sep + message |
| 89 | else: |