(method,data_dict)
| 99 | } |
| 100 | |
| 101 | def process_invalid_data(method,data_dict): |
| 102 | answer_generation = data_dict['answer_generation'] |
| 103 | functions = answer_generation['function'] |
| 104 | query = answer_generation['query'] |
| 105 | eg = ExecutionGraph() |
| 106 | last_node = generate_init_message_node(eg,functions,query) |
| 107 | if 'CoT' in method: |
| 108 | trail = random.choice(data_dict["trys"]) |
| 109 | index = 0 |
| 110 | while index < len(trail['chain']): |
| 111 | message = trail['chain'][index] |
| 112 | if message['node_type'] == 'Action': |
| 113 | node = ExecutionNode(role='tool', message={ |
| 114 | 'name':message['description'], |
| 115 | 'arguments':(trail['chain'][index+1]['description']), |
| 116 | 'response':(trail['chain'][index+1]['observation'])}) |
| 117 | |
| 118 | index = index + 1 |
| 119 | elif message['node_type'] == 'Thought': |
| 120 | node = ExecutionNode(role='assistant', |
| 121 | message=message['description']) |
| 122 | else: |
| 123 | raise NotImplementedError(f"Unknown node_type: {message['node_type']}") |
| 124 | index = index + 1 |
| 125 | |
| 126 | eg.add_node(node) |
| 127 | eg[last_node,node] = None |
| 128 | last_node = node |
| 129 | eg = eg.reduce_graph_to_sequence() |
| 130 | |
| 131 | elif 'DFS' in method: |
| 132 | |
| 133 | def DFS(root): |
| 134 | if len(root['children']) == 0: |
| 135 | node = ExecutionNode(role=root['node_type'],message=root) |
| 136 | eg.add_node(node) |
| 137 | return node |
| 138 | else: |
| 139 | child_nodes = [DFS(node) for node in root['children']] |
| 140 | root['children'] = None |
| 141 | root_node = ExecutionNode(role=root['node_type'],message=root) |
| 142 | eg.add_node(root_node) |
| 143 | for child_node in child_nodes: |
| 144 | eg.add_edge(root_node,child_node) |
| 145 | return root_node |
| 146 | for node in data_dict['tree']['tree']['children']: |
| 147 | eg[last_node,DFS(node)] = None |
| 148 | |
| 149 | |
| 150 | # purify the graph |
| 151 | def purify_graph(node:ExecutionNode): |
| 152 | if node.role == 'Action': |
| 153 | adj_nodes = eg.get_adjacent_node(node) |
| 154 | for adj_node in adj_nodes: |
| 155 | adj_node = eg[adj_node] |
| 156 | if adj_node.role == 'Action Input': |
| 157 | node.role = 'tool' |
| 158 | node.message = { |
nothing calls this directly
no test coverage detected