()
| 10 | |
| 11 | |
| 12 | def main(): |
| 13 | if args.worker_address: |
| 14 | worker_addr = args.worker_address |
| 15 | else: |
| 16 | controller_addr = args.controller_address |
| 17 | ret = requests.post(controller_addr + "/refresh_all_workers") |
| 18 | ret = requests.post(controller_addr + "/list_models") |
| 19 | models = ret.json()["models"] |
| 20 | models.sort() |
| 21 | print(f"Models: {models}") |
| 22 | |
| 23 | ret = requests.post( |
| 24 | controller_addr + "/get_worker_address", json={"model": args.model_name} |
| 25 | ) |
| 26 | worker_addr = ret.json()["address"] |
| 27 | print(f"worker_addr: {worker_addr}") |
| 28 | |
| 29 | if worker_addr == "": |
| 30 | return |
| 31 | |
| 32 | conv = get_conv_template("vicuna_v1.1") |
| 33 | conv.append_message(conv.roles[0], "Tell me a story with more than 1000 words") |
| 34 | prompt_template = conv.get_prompt() |
| 35 | prompts = [prompt_template for _ in range(args.n_thread)] |
| 36 | |
| 37 | headers = {"User-Agent": "fastchat Client"} |
| 38 | ploads = [ |
| 39 | { |
| 40 | "model": args.model_name, |
| 41 | "prompt": prompts[i], |
| 42 | "max_new_tokens": args.max_new_tokens, |
| 43 | "temperature": 0.0, |
| 44 | # "stop": conv.sep, |
| 45 | } |
| 46 | for i in range(len(prompts)) |
| 47 | ] |
| 48 | |
| 49 | def send_request(results, i): |
| 50 | if args.test_dispatch: |
| 51 | ret = requests.post( |
| 52 | controller_addr + "/get_worker_address", json={"model": args.model_name} |
| 53 | ) |
| 54 | thread_worker_addr = ret.json()["address"] |
| 55 | else: |
| 56 | thread_worker_addr = worker_addr |
| 57 | print(f"thread {i} goes to {thread_worker_addr}") |
| 58 | response = requests.post( |
| 59 | thread_worker_addr + "/worker_generate_stream", |
| 60 | headers=headers, |
| 61 | json=ploads[i], |
| 62 | stream=False, |
| 63 | ) |
| 64 | k = list( |
| 65 | response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0") |
| 66 | ) |
| 67 | # print(k) |
| 68 | response_new_words = json.loads(k[-2].decode("utf-8"))["text"] |
| 69 | error_code = json.loads(k[-2].decode("utf-8"))["error_code"] |
no test coverage detected
searching dependent graphs…