(stream=True)
| 45 | |
| 46 | |
def main(stream=True):
    """Run the interactive command-line chat loop.

    Prompts are read from standard input until the user types ``exit``.
    The following inputs (compared after stripping surrounding
    whitespace) are treated as commands rather than chat messages:

    * ``exit``   - leave the loop.
    * ``clear``  - call ``clear_screen()`` and continue with the message
      list it returns.
    * ``vim``    - replace the prompt with the value returned by
      ``vim_input()`` and echo it.
    * ``stream`` - toggle between streaming and non-streaming output.

    Any other input is appended to the conversation as a ``user``
    message and passed to ``model.chat``; the reply is printed and
    appended to the conversation as an ``assistant`` message.

    Args:
        stream: When true, the reply is printed incrementally while it
            is being generated; otherwise it is printed once complete.
    """
    model, tokenizer = init_model()
    messages = clear_screen()
    # Backend availability is queried once here rather than once per
    # streamed chunk.
    mps_available = torch.backends.mps.is_available()

    while True:
        prompt = input(Fore.GREEN + Style.BRIGHT + "\n用户:" + Style.NORMAL)
        command = prompt.strip()
        if command == "exit":
            break
        if command == "clear":
            messages = clear_screen()
            continue
        if command == 'vim':
            prompt = vim_input()
            print(prompt)
            # The editor text may itself be a command (e.g. "stream").
            command = prompt.strip()

        print(Fore.CYAN + Style.BRIGHT + "\nBaichuan 2:" + Style.NORMAL, end='')
        if command == "stream":
            stream = not stream
            print(Fore.YELLOW + "({}流式生成)\n".format("开启" if stream else "关闭"), end='')
            continue

        messages.append({"role": "user", "content": prompt})
        if stream:
            # Reset per turn: if generation is interrupted before the
            # first chunk arrives, the history must record an empty reply
            # instead of raising NameError or reusing the previous turn's
            # text.
            response = ""
            position = 0
            try:
                for response in model.chat(tokenizer, messages, stream=True):
                    # Each yielded value is treated as the full reply so
                    # far; only the part not yet shown is printed.
                    print(response[position:], end='', flush=True)
                    position = len(response)
                    if mps_available:
                        torch.mps.empty_cache()
            except KeyboardInterrupt:
                # Ctrl-C stops the current reply but keeps the session
                # alive; whatever was generated so far is kept.
                pass
            print()
        else:
            response = model.chat(tokenizer, messages)
            print(response)
            if mps_available:
                torch.mps.empty_cache()
        messages.append({"role": "assistant", "content": response})

    print(Style.RESET_ALL)
| 84 | |
| 85 | |
| 86 | if __name__ == "__main__": |
no test coverage detected