()
| 8 | |
| 9 | |
def main():
    """Send one message to a model worker and stream the reply to stdout.

    Reads the module-level ``args`` namespace: ``model_name``,
    ``worker_address``, ``controller_address``, ``message``,
    ``temperature`` and ``max_new_tokens``.

    When ``args.worker_address`` is empty, the controller is asked to
    refresh its workers, list its models, and hand back a worker address
    for ``args.model_name``. If the resolved address is the empty string,
    a notice is printed and the function returns without sending anything.
    """
    model_name = args.model_name

    if args.worker_address:
        worker_addr = args.worker_address
    else:
        controller_addr = args.controller_address
        # Response body is not needed; the call is made for its effect on
        # the controller.
        requests.post(controller_addr + "/refresh_all_workers")
        ret = requests.post(controller_addr + "/list_models")
        models = ret.json()["models"]
        models.sort()
        print(f"Models: {models}")

        ret = requests.post(
            controller_addr + "/get_worker_address", json={"model": model_name}
        )
        worker_addr = ret.json()["address"]
        print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        print(f"No available workers for {model_name}")
        return

    # Build the prompt: the user's message followed by an empty slot for
    # the second role's reply.
    conv = get_conversation_template(model_name)
    conv.append_message(conv.roles[0], args.message)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    headers = {"User-Agent": "FastChat Client"}
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": args.temperature,
        "max_new_tokens": args.max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }

    # A streamed response holds its connection until the body is fully
    # consumed or the response is closed; the ``with`` block guarantees the
    # close even if decoding a chunk raises part-way through.
    with requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
    ) as response:
        print(f"{conv.roles[0]}: {args.message}")
        print(f"{conv.roles[1]}: ", end="")
        prev = 0
        # Chunks are NUL-delimited JSON objects.
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                output = data["text"].strip()
                # Print only the part past what was already shown; this
                # assumes each chunk's "text" is the full output so far
                # (TODO confirm against the worker).
                print(output[prev:], end="", flush=True)
                prev = len(output)
    print("")
| 65 | |
| 66 | |
| 67 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…