(method,data_dict)
| 72 | } |
| 73 | } |
| 74 | def process_invalid_data(method,data_dict): |
| 75 | answer_generation = data_dict['answer_generation'] |
| 76 | functions = answer_generation['function'] |
| 77 | query = answer_generation['query'] |
| 78 | eg = ExecutionGraph() |
| 79 | last_node = generate_init_message_node(eg,functions,query) |
| 80 | if 'CoT' in method or 'cot' in method: |
| 81 | trail = random.choice(data_dict["trys"]) |
| 82 | |
| 83 | |
| 84 | index = 0 |
| 85 | while index < len(trail['chain']): |
| 86 | message = trail['chain'][index] |
| 87 | if message['node_type'] == 'Action': |
| 88 | node = ExecutionNode(role='tool', message={ |
| 89 | 'name':message['description'], |
| 90 | 'arguments':(trail['chain'][index+1]['description']), |
| 91 | 'response':(trail['chain'][index+1]['observation'])}) |
| 92 | |
| 93 | index = index + 1 |
| 94 | elif message['node_type'] == 'Thought': |
| 95 | node = ExecutionNode(role='assistant', |
| 96 | message=message['description']) |
| 97 | else: |
| 98 | raise NotImplementedError(f"Unknown node_type: {message['node_type']}") |
| 99 | index = index + 1 |
| 100 | |
| 101 | eg.add_node(node) |
| 102 | eg[last_node,node] = None |
| 103 | last_node = node |
| 104 | eg = eg.reduce_graph_to_sequence() |
| 105 | |
| 106 | elif 'DFS' in method or 'dfs' in method: |
| 107 | |
def DFS(root):
    """Recursively convert one search-tree dict (and its subtree) into graph nodes.

    Each dict becomes ``ExecutionNode(role=root['node_type'], message=root)``
    and is registered on the enclosing ``eg`` graph.  Children are converted
    before their parent (post-order) and are then linked to it, in their
    original order, with ``eg.add_edge(parent, child)``.

    Side effect: when a dict has children, its ``'children'`` entry is
    overwritten with ``None`` before the dict is stored as the node message
    (presumably so the message does not embed the whole subtree -- TODO
    confirm).  Leaf dicts are stored unchanged.

    Returns the ExecutionNode created for ``root``.
    """
    child_nodes = []
    if len(root['children']) != 0:
        # Build the subtree first so every child node exists before its parent.
        child_nodes = [DFS(child) for child in root['children']]
        root['children'] = None
    parent = ExecutionNode(role=root['node_type'], message=root)
    eg.add_node(parent)
    for child_node in child_nodes:
        eg.add_edge(parent, child_node)
    return parent
| 121 | for node in data_dict['tree']['tree']['children']: |
| 122 | eg[last_node,DFS(node)] = None |
| 123 | |
| 124 | |
| 125 | # purify the graph |
| 126 | def purify_graph(node:ExecutionNode): |
| 127 | if node.role == 'Action': |
| 128 | adj_nodes = eg.get_adjacent_node(node) |
| 129 | for adj_node in adj_nodes: |
| 130 | adj_node = eg[adj_node] |
| 131 | if adj_node.role == 'Action Input': |
no test coverage detected