(model_name, messages, temperature, top_p, max_new_tokens)
| 376 | |
| 377 | |
def mistral_api_stream_iter(model_name, messages, temperature, top_p, max_new_tokens):
    """Stream a Mistral chat completion, yielding the cumulative text so far.

    The API key is read from the ``MISTRAL_API_KEY`` environment variable,
    the request parameters are logged, ``messages`` are converted into
    ``ChatMessage`` objects and the request is sent via
    ``MistralClient.chat_stream``.

    Args:
        model_name: Model identifier, forwarded as ``model``.
        messages: Iterable of mappings exposing ``"role"`` and ``"content"``.
        temperature: Forwarded unchanged to ``chat_stream``.
        top_p: Forwarded unchanged to ``chat_stream``.
        max_new_tokens: Forwarded to ``chat_stream`` as ``max_tokens``.

    Yields:
        A dict ``{"text": <all content received so far>, "error_code": 0}``
        for every streamed chunk whose first choice carries non-``None``
        delta content.

    Raises:
        KeyError: When iteration starts and ``MISTRAL_API_KEY`` is not set.
    """
    # Function-scope imports: mistralai is loaded only when the generator runs.
    from mistralai.client import MistralClient
    from mistralai.models.chat_completion import ChatMessage

    client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

    # This dict is only used for logging; the actual request below passes
    # the values under the keyword names chat_stream expects.
    gen_params = dict(
        model=model_name,
        prompt=messages,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
    )
    logger.info(f"==== request ====\n{gen_params}")

    # Convert the plain role/content mappings into client message objects.
    chat_history = []
    for entry in messages:
        chat_history.append(ChatMessage(role=entry["role"], content=entry["content"]))

    stream = client.chat_stream(
        model=model_name,
        temperature=temperature,
        messages=chat_history,
        max_tokens=max_new_tokens,
        top_p=top_p,
    )

    accumulated = ""
    for event in stream:
        piece = event.choices[0].delta.content
        if piece is None:
            # Chunks without content produce no output.
            continue
        accumulated += piece
        yield {"text": accumulated, "error_code": 0}
| 418 | |
| 419 | |
| 420 | def nvidia_api_stream_iter(model_name, messages, temp, top_p, max_tokens, api_base): |
no test coverage detected
searching dependent graphs…