(messages, api_key, api_type, api_endpoint, return_planning = False, return_results = False)
| 889 | return True |
| 890 | |
| 891 | def chat_huggingface(messages, api_key, api_type, api_endpoint, return_planning = False, return_results = False): |
| 892 | start = time.time() |
| 893 | context = messages[:-1] |
| 894 | input = messages[-1]["content"] |
| 895 | logger.info("*"*80) |
| 896 | logger.info(f"input: {input}") |
| 897 | |
| 898 | task_str = parse_task(context, input, api_key, api_type, api_endpoint) |
| 899 | |
| 900 | if "error" in task_str: |
| 901 | record_case(success=False, **{"input": input, "task": task_str, "reason": f"task parsing error: {task_str['error']['message']}", "op":"report message"}) |
| 902 | return {"message": task_str["error"]["message"]} |
| 903 | |
| 904 | task_str = task_str.strip() |
| 905 | logger.info(task_str) |
| 906 | |
| 907 | try: |
| 908 | tasks = json.loads(task_str) |
| 909 | except Exception as e: |
| 910 | logger.debug(e) |
| 911 | response = chitchat(messages, api_key, api_type, api_endpoint) |
| 912 | record_case(success=False, **{"input": input, "task": task_str, "reason": "task parsing fail", "op":"chitchat"}) |
| 913 | return {"message": response} |
| 914 | |
| 915 | if task_str == "[]": # using LLM response for empty task |
| 916 | record_case(success=False, **{"input": input, "task": [], "reason": "task parsing fail: empty", "op": "chitchat"}) |
| 917 | response = chitchat(messages, api_key, api_type, api_endpoint) |
| 918 | return {"message": response} |
| 919 | |
| 920 | if len(tasks) == 1 and tasks[0]["task"] in ["summarization", "translation", "conversational", "text-generation", "text2text-generation"]: |
| 921 | record_case(success=True, **{"input": input, "task": tasks, "reason": "chitchat tasks", "op": "chitchat"}) |
| 922 | response = chitchat(messages, api_key, api_type, api_endpoint) |
| 923 | return {"message": response} |
| 924 | |
| 925 | tasks = unfold(tasks) |
| 926 | tasks = fix_dep(tasks) |
| 927 | logger.debug(tasks) |
| 928 | |
| 929 | if return_planning: |
| 930 | return tasks |
| 931 | |
| 932 | results = {} |
| 933 | threads = [] |
| 934 | tasks = tasks[:] |
| 935 | d = dict() |
| 936 | retry = 0 |
| 937 | while True: |
| 938 | num_thread = len(threads) |
| 939 | for task in tasks: |
| 940 | # logger.debug(f"d.keys(): {d.keys()}, dep: {dep}") |
| 941 | for dep_id in task["dep"]: |
| 942 | if dep_id >= task["id"]: |
| 943 | task["dep"] = [-1] |
| 944 | break |
| 945 | dep = task["dep"] |
| 946 | if dep[0] == -1 or len(list(set(dep).intersection(d.keys()))) == len(dep): |
| 947 | tasks.remove(task) |
| 948 | thread = threading.Thread(target=run_task, args=(input, task, d, api_key, api_type, api_endpoint)) |
no test coverage detected