()
| 7 | |
| 8 | |
def main():
    """Send one user message to a model worker and stream the reply to stdout.

    Reads its configuration from the module-level ``args`` object:
    ``worker_address``, ``controller_address``, ``model_name``, ``message``
    and ``max_new_tokens``.

    If ``args.worker_address`` is set it is used directly. Otherwise the
    controller is asked to refresh its workers, the available model names are
    printed, and the controller is queried for a worker serving
    ``args.model_name``. When the controller answers with an empty address a
    short notice is printed and the function returns without sending anything.

    Returns:
        None. All results are written to stdout.
    """
    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        # Only the side effect of the refresh matters; its response is unused.
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = sorted(ret.json()["models"])
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address",
            json={"model": args.model_name},
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        # Tell the user why nothing happens instead of exiting silently.
        print(f"No available workers for {args.model_name}")
        return

    # Build the prompt from a fresh copy of the default conversation template.
    conv = default_conversation.copy()
    conv.append_message(conv.roles[0], args.message)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "LLaVA Client"}
    pload = {
        "model": args.model_name,
        "prompt": prompt,
        "max_new_tokens": args.max_new_tokens,
        "temperature": 0.7,
        "stop": conv.sep,
    }

    print(prompt.replace(conv.sep, "\n"), end="")
    # The context manager releases the streamed connection even if decoding
    # a chunk raises part-way through the stream.
    with requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=pload,
        stream=True,
    ) as response:
        # Chunks are NUL-delimited JSON documents.
        for chunk in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if not chunk:
                continue
            data = json.loads(chunk.decode("utf-8"))
            # Keep only the text after the last separator.
            # NOTE(review): presumably this is the newest assistant turn - confirm.
            output = data["text"].split(conv.sep)[-1]
            # "\r" moves the cursor back to the line start so the next
            # chunk overwrites this one on the terminal.
            print(output, end="\r")
    print("")
| 50 | |
| 51 | |
| 52 | if __name__ == "__main__": |
no test coverage detected