Prepare the session for generating messages. :param stream: whether to enable streaming :param system_prompt_variables: system prompt variables :param retrieval_log: whether to log retrieval :param chat_completion_messages: chat completion messages :param chat_completion_input_functions: chat completion input functions :return: None
(
self,
stream: bool,
system_prompt_variables: Dict,
retrieval_log: bool = False,
chat_completion_messages: List[ChatCompletionAnyMessage] = None,
chat_completion_input_functions: List[ChatCompletionFunction] = None,
)
| 92 | ) |
| 93 | |
| 94 | async def prepare( |
| 95 | self, |
| 96 | stream: bool, |
| 97 | system_prompt_variables: Dict, |
| 98 | retrieval_log: bool = False, |
| 99 | chat_completion_messages: List[ChatCompletionAnyMessage] = None, |
| 100 | chat_completion_input_functions: List[ChatCompletionFunction] = None, |
| 101 | ): |
| 102 | """ |
| 103 | Prepare the session for generating messages. |
| 104 | :param stream: whether to enable streaming |
| 105 | :param system_prompt_variables: system prompt variables |
| 106 | :param retrieval_log: whether to log retrieval |
| 107 | :param chat_completion_messages: chat completion messages |
| 108 | :param chat_completion_input_functions: chat completion input functions |
| 109 | :return: None |
| 110 | """ |
| 111 | |
| 112 | if self.chat and chat_completion_messages is not None: |
| 113 | raise ValueError("chat_completion_messages should be None when chat is not None.") |
| 114 | |
| 115 | if not self.chat and chat_completion_messages is None: |
| 116 | raise ValueError("chat_completion_messages should not be None when chat is None.") |
| 117 | |
| 118 | if chat_completion_input_functions is not None and chat_completion_messages is None: |
| 119 | raise ValueError("chat_completion_input_functions should be None when chat_completion_messages is None.") |
| 120 | |
| 121 | # check chat lock |
| 122 | if self.chat and await self.chat.is_chat_locked(): |
| 123 | raise MessageGenerationInvalidRequestException( |
| 124 | f"Chat {self.chat.chat_id} is locked. Please try again later." |
| 125 | ) |
| 126 | |
| 127 | # Get model |
| 128 | try: |
| 129 | self.model = await get_model(self.assistant.model_id) |
| 130 | except Exception as e: |
| 131 | raise MessageGenerationInvalidRequestException(f"Failed to load model {self.assistant.model_id}.") |
| 132 | |
| 133 | # Check model streaming |
| 134 | if not self.model.allow_streaming() and stream: |
| 135 | raise MessageGenerationInvalidRequestException( |
| 136 | f"Assistant model {self.model.model_id} does not support streaming. " |
| 137 | ) |
| 138 | |
| 139 | # Get chat memory |
| 140 | if self.chat: |
| 141 | self.chat_memory_messages = await get_chat_memory_messages(self.chat) |
| 142 | logger.debug(f"Chat memory: {self.chat_memory_messages}") |
| 143 | else: |
| 144 | # use user input message as chat memory |
| 145 | self.chat_memory_messages = [ |
| 146 | message.model_dump() |
| 147 | for message in chat_completion_messages |
| 148 | if message.role != ChatCompletionRole.SYSTEM |
| 149 | ] |
| 150 | |
| 151 | # Get tools |
no test coverage detected