MCPcopy Index your code
hub / github.com/THUDM/AgentBench / sync_start_sample

Method sync_start_sample

src/server/tasks/knowledgegraph/task.py:71–213  ·  view source on GitHub ↗
(self, index: SampleIndex, session: Session)

Source from the content-addressed store, hash-verified

69 return await super().start_sample(index, session)
70
71 def sync_start_sample(self, index: SampleIndex, session: Session) -> TaskSampleExecutionResult:
72 self.logger.info(f'starting sample {index} with session id {session.id}')
73
74 data_item = self.inputs[index]
75 question = data_item['question']
76 entities = data_item['entities']
77 self.logger.info(f'[session {session.id}] Processing question: {question[:50]}...')
78
79 session_id, _, urls = self.env_controller.sync_start_session(ENV_SUBTYPE)
80 try:
81 sparql_url = urls[ENV_SUBTYPE]
82 sparql_executor = SparqlExecuter(sparql_url)
83 api = API(sparql_executor, session.id)
84
85 session.inject(ChatCompletionSystemMessageParam(
86 role='system',
87 content=INSTRUCTIONS.format(max_round=self.max_rounds)
88 ))
89 if self.one_shot:
90 session.inject(ONE_SHOT)
91 session.inject(ChatCompletionUserMessageParam(
92 role='user',
93 content=f'{question}\nEntities: [{", ".join([entity for entity in entities])}]'
94 ))
95
96 variables_list = []
97 for current_round in range(self.max_rounds):
98 response = session.sync_action()
99
100 tool_calls = response.messages[0].get('tool_calls') or []
101 if not tool_calls:
102 try:
103 final_message = response.messages[0].get('content') or ''
104 final_message = final_message.split("Observation:")[0]
105 final_message = final_message.replace("\\_", "_")
106 final_answer = re.findall(r'(?:Find|Final) Answer: #(\d+)', final_message)
107
108 if final_answer:
109 var_idx = int(final_answer[0])
110 answer_variable = variables_list[var_idx]
111
112 # base reward for submitting answer
113 predicted_answer = set(api.final_execute(answer_variable))
114 gold_answer = self.targets[index]
115
116 # calculate correctness and F1 score
117 is_correct = (len(gold_answer.intersection(predicted_answer)) == len(gold_answer) and
118 len(gold_answer.intersection(predicted_answer)) == len(predicted_answer))
119 f1_score = self._calculate_f1(predicted_answer, gold_answer)
120
121 session.inject(RewardHistoryItem(reward=int(is_correct), score=f1_score))
122 return TaskSampleExecutionResult(status=SampleStatus.COMPLETED)
123
124 except IndexError:
125 self.logger.info(f'[session {session.id}] invalid variable index')
126 return TaskSampleExecutionResult(status=SampleStatus.AGENT_VALIDATION_FAILED)
127
128 except Exception:

Callers

nothing calls this directly

Calls 6

final_execute · Method · 0.95
_calculate_f1 · Method · 0.95
SparqlExecuter · Class · 0.85
API · Class · 0.85
intersection · Method · 0.80

Tested by

no test coverage detected