(
model_name,
model_id,
messages,
temperature,
top_p,
max_new_tokens,
api_key=None,
api_base=None,
)
| 303 | |
| 304 | |
| 305 | def ai2_api_stream_iter( |
| 306 | model_name, |
| 307 | model_id, |
| 308 | messages, |
| 309 | temperature, |
| 310 | top_p, |
| 311 | max_new_tokens, |
| 312 | api_key=None, |
| 313 | api_base=None, |
| 314 | ): |
| 315 | # get keys and needed values |
| 316 | ai2_key = api_key or os.environ.get("AI2_API_KEY") |
| 317 | api_base = api_base or "https://inferd.allen.ai/api/v1/infer" |
| 318 | |
| 319 | # Make requests |
| 320 | gen_params = { |
| 321 | "model": model_name, |
| 322 | "prompt": messages, |
| 323 | "temperature": temperature, |
| 324 | "top_p": top_p, |
| 325 | "max_new_tokens": max_new_tokens, |
| 326 | } |
| 327 | logger.info(f"==== request ====\n{gen_params}") |
| 328 | |
| 329 | # AI2 uses vLLM, which requires that `top_p` be 1.0 for greedy sampling: |
| 330 | # https://github.com/vllm-project/vllm/blob/v0.1.7/vllm/sampling_params.py#L156-L157 |
| 331 | if temperature == 0.0 and top_p < 1.0: |
| 332 | raise ValueError("top_p must be 1 when temperature is 0.0") |
| 333 | |
| 334 | res = requests.post( |
| 335 | api_base, |
| 336 | stream=True, |
| 337 | headers={"Authorization": f"Bearer {ai2_key}"}, |
| 338 | json={ |
| 339 | "model_id": model_id, |
| 340 | # This input format is specific to the Tulu2 model. Other models |
| 341 | # may require different input formats. See the model's schema |
| 342 | # documentation on InferD for more information. |
| 343 | "input": { |
| 344 | "messages": messages, |
| 345 | "opts": { |
| 346 | "max_tokens": max_new_tokens, |
| 347 | "temperature": temperature, |
| 348 | "top_p": top_p, |
| 349 | "logprobs": 1, # increase for more choices |
| 350 | }, |
| 351 | }, |
| 352 | }, |
| 353 | timeout=5, |
| 354 | ) |
| 355 | |
| 356 | if res.status_code != 200: |
| 357 | logger.error(f"unexpected response ({res.status_code}): {res.text}") |
| 358 | raise ValueError("unexpected response from InferD", res) |
| 359 | |
| 360 | text = "" |
| 361 | for line in res.iter_lines(): |
| 362 | if line: |
no outgoing calls
no test coverage detected
searching dependent graphs…