(
prompt_text: str,
top_p: float = 0.2,
temperature: float = 0.1,
repetition_penalty: float = 1.1,
max_new_tokens: int = 1024,
truncate_length: int = 1024,
retry: bool = False
)
| 222 | |
| 223 | |
| 224 | def main( |
| 225 | prompt_text: str, |
| 226 | top_p: float = 0.2, |
| 227 | temperature: float = 0.1, |
| 228 | repetition_penalty: float = 1.1, |
| 229 | max_new_tokens: int = 1024, |
| 230 | truncate_length: int = 1024, |
| 231 | retry: bool = False |
| 232 | ): |
| 233 | if 'ci_history' not in st.session_state: |
| 234 | st.session_state.ci_history = [] |
| 235 | |
| 236 | |
| 237 | if prompt_text == "" and retry == False: |
| 238 | print("\n== Clean ==\n") |
| 239 | st.session_state.chat_history = [] |
| 240 | return |
| 241 | |
| 242 | history: list[Conversation] = st.session_state.chat_history |
| 243 | for conversation in history: |
| 244 | conversation.show() |
| 245 | |
| 246 | if retry: |
| 247 | print("\n== Retry ==\n") |
| 248 | last_user_conversation_idx = None |
| 249 | for idx, conversation in enumerate(history): |
| 250 | if conversation.role == Role.USER: |
| 251 | last_user_conversation_idx = idx |
| 252 | if last_user_conversation_idx is not None: |
| 253 | prompt_text = history[last_user_conversation_idx].content |
| 254 | del history[last_user_conversation_idx:] |
| 255 | if prompt_text: |
| 256 | prompt_text = prompt_text.strip() |
| 257 | role = Role.USER |
| 258 | append_conversation(Conversation(role, prompt_text), history) |
| 259 | |
| 260 | placeholder = st.container() |
| 261 | message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") |
| 262 | markdown_placeholder = message_placeholder.empty() |
| 263 | |
| 264 | for _ in range(5): |
| 265 | output_text = '' |
| 266 | for response in client.generate_stream( |
| 267 | system=SYSTEM_PROMPT, |
| 268 | tools=None, |
| 269 | history=history, |
| 270 | do_sample=True, |
| 271 | max_new_token=max_new_tokens, |
| 272 | temperature=temperature, |
| 273 | top_p=top_p, |
| 274 | stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)], |
| 275 | repetition_penalty=repetition_penalty, |
| 276 | ): |
| 277 | token = response.token |
| 278 | if response.token.special: |
| 279 | print("\n==Output:==\n", output_text) |
| 280 | match token.text.strip(): |
| 281 | case '<|user|>': |
nothing calls this directly
no test coverage detected