(self, index: SampleIndex, session: Session)
| 69 | return await super().start_sample(index, session) |
| 70 | |
| 71 | def sync_start_sample(self, index: SampleIndex, session: Session) -> TaskSampleExecutionResult: |
| 72 | self.logger.info(f'starting sample {index} with session id {session.id}') |
| 73 | |
| 74 | data_item = self.inputs[index] |
| 75 | question = data_item['question'] |
| 76 | entities = data_item['entities'] |
| 77 | self.logger.info(f'[session {session.id}] Processing question: {question[:50]}...') |
| 78 | |
| 79 | session_id, _, urls = self.env_controller.sync_start_session(ENV_SUBTYPE) |
| 80 | try: |
| 81 | sparql_url = urls[ENV_SUBTYPE] |
| 82 | sparql_executor = SparqlExecuter(sparql_url) |
| 83 | api = API(sparql_executor, session.id) |
| 84 | |
| 85 | session.inject(ChatCompletionSystemMessageParam( |
| 86 | role='system', |
| 87 | content=INSTRUCTIONS.format(max_round=self.max_rounds) |
| 88 | )) |
| 89 | if self.one_shot: |
| 90 | session.inject(ONE_SHOT) |
| 91 | session.inject(ChatCompletionUserMessageParam( |
| 92 | role='user', |
| 93 | content=f'{question}\nEntities: [{", ".join([entity for entity in entities])}]' |
| 94 | )) |
| 95 | |
| 96 | variables_list = [] |
| 97 | for current_round in range(self.max_rounds): |
| 98 | response = session.sync_action() |
| 99 | |
| 100 | tool_calls = response.messages[0].get('tool_calls') or [] |
| 101 | if not tool_calls: |
| 102 | try: |
| 103 | final_message = response.messages[0].get('content') or '' |
| 104 | final_message = final_message.split("Observation:")[0] |
| 105 | final_message = final_message.replace("\\_", "_") |
| 106 | final_answer = re.findall(r'(?:Find|Final) Answer: #(\d+)', final_message) |
| 107 | |
| 108 | if final_answer: |
| 109 | var_idx = int(final_answer[0]) |
| 110 | answer_variable = variables_list[var_idx] |
| 111 | |
| 112 | # base reward for submitting answer |
| 113 | predicted_answer = set(api.final_execute(answer_variable)) |
| 114 | gold_answer = self.targets[index] |
| 115 | |
| 116 | # calculate correctness and F1 score |
| 117 | is_correct = (len(gold_answer.intersection(predicted_answer)) == len(gold_answer) and |
| 118 | len(gold_answer.intersection(predicted_answer)) == len(predicted_answer)) |
| 119 | f1_score = self._calculate_f1(predicted_answer, gold_answer) |
| 120 | |
| 121 | session.inject(RewardHistoryItem(reward=int(is_correct), score=f1_score)) |
| 122 | return TaskSampleExecutionResult(status=SampleStatus.COMPLETED) |
| 123 | |
| 124 | except IndexError: |
| 125 | self.logger.info(f'[session {session.id}] invalid variable index') |
| 126 | return TaskSampleExecutionResult(status=SampleStatus.AGENT_VALIDATION_FAILED) |
| 127 | |
| 128 | except Exception: |
nothing calls this directly
no test coverage detected