(model_name, messages, temp, top_p, max_tokens, api_base)
| 418 | |
| 419 | |
def nvidia_api_stream_iter(
    model_name, messages, temp, top_p, max_tokens, api_base, *, timeout=(10, 60)
):
    """Stream a chat completion from an NVIDIA API endpoint.

    POSTs an OpenAI-style chat request with ``"stream": True`` to ``api_base``
    and parses the server-sent-event (SSE) response line by line, yielding the
    accumulated text after every chunk that carries new content.

    Args:
        model_name: Must be "llama2-70b-steerlm-chat" or "yi-34b-chat".
        messages: Chat history, forwarded verbatim as the ``messages`` field.
        temp: Sampling temperature. ``0.0`` is replaced by ``0.0001`` because
            the NVIDIA API does not accept a zero temperature.
        top_p: Nucleus-sampling parameter, forwarded verbatim.
        max_tokens: Generation length limit, forwarded verbatim.
        api_base: Full URL the request is POSTed to.
        timeout: ``requests`` timeout, either a single number or a
            ``(connect, read)`` tuple. With ``stream=True`` the read timeout
            bounds the gap between streamed chunks, so it must cover the
            model's time-to-first-token (the previous hard-coded 1 s did not).

    Yields:
        ``{"text": <all text received so far>, "error_code": 0}``.

    Raises:
        AssertionError: if ``model_name`` is not one of the supported models.
        KeyError: if the ``NVIDIA_API_KEY`` environment variable is unset.
        requests.HTTPError: if the endpoint answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
    """
    assert model_name in ["llama2-70b-steerlm-chat", "yi-34b-chat"]

    api_key = os.environ["NVIDIA_API_KEY"]
    headers = {
        "Authorization": f"Bearer {api_key}",
        "accept": "text/event-stream",
        "content-type": "application/json",
    }
    # nvidia api does not accept 0 temperature
    if temp == 0.0:
        temp = 0.0001

    payload = {
        "messages": messages,
        "temperature": temp,
        "top_p": top_p,
        "max_tokens": max_tokens,
        "seed": 42,
        "stream": True,
    }
    logger.info("==== request ====\n%s", payload)

    # Using the response as a context manager releases the pooled connection
    # even when the consumer abandons this generator before the stream ends
    # (generator close() raises GeneratorExit at the yield inside the `with`).
    with requests.post(
        api_base, headers=headers, json=payload, stream=True, timeout=timeout
    ) as response:
        # Fail clearly on HTTP errors instead of parsing an error body as SSE.
        response.raise_for_status()

        text = ""
        for raw_line in response.iter_lines():
            if not raw_line:
                continue  # blank lines only separate SSE events
            line = raw_line.decode("utf-8")
            if line.endswith("[DONE]"):
                break
            # Only "data:" fields carry JSON chunks; skip SSE comments /
            # keep-alives (lines starting with ":") and fields like "event:".
            if not line.startswith("data:"):
                continue
            body = line[len("data:"):].strip()
            if not body:
                continue
            chunk = json.loads(body)
            choices = chunk.get("choices") or []
            if not choices:
                continue
            # Some deltas carry only a role or a finish_reason and no content.
            content = (choices[0].get("delta") or {}).get("content")
            if not content:
                continue
            text += content
            yield {"text": text, "error_code": 0}
no outgoing calls
no test coverage detected
searching dependent graphs…