Execute the write_test tool. Args: test_description: The specification description. test_file_name: The name of the file where the generated tests will be saved. Returns: Generated unit tests or an error message.
(self, test_description: str, test_file_name: str)
| 58 | arbitrary_types_allowed = True |
| 59 | |
| 60 | def _execute(self, test_description: str, test_file_name: str) -> str: |
| 61 | """ |
| 62 | Execute the write_test tool. |
| 63 | |
| 64 | Args: |
| 65 | test_description : The specification description. |
| 66 | test_file_name: The name of the file where the generated tests will be saved. |
| 67 | |
| 68 | Returns: |
| 69 | Generated unit tests or error message. |
| 70 | """ |
| 71 | prompt = PromptReader.read_tools_prompt(__file__, "write_test.txt") |
| 72 | prompt = prompt.replace("{goals}", AgentPromptBuilder.add_list_items_to_string(self.goals)) |
| 73 | prompt = prompt.replace("{test_description}", test_description) |
| 74 | |
| 75 | spec_response = self.tool_response_manager.get_last_response("WriteSpecTool") |
| 76 | if spec_response != "": |
| 77 | prompt = prompt.replace("{spec}", |
| 78 | "Please generate unit tests based on the following specification description:\n" + spec_response) |
| 79 | else: |
| 80 | spec_response = self.tool_response_manager.get_last_response() |
| 81 | if spec_response != "": |
| 82 | prompt = prompt.replace("{spec}", |
| 83 | "Please generate unit tests based on the following specification description:\n" + spec_response) |
| 84 | |
| 85 | messages = [{"role": "system", "content": prompt}] |
| 86 | logger.info(prompt) |
| 87 | |
| 88 | organisation = Agent.find_org_by_agent_id(self.toolkit_config.session, agent_id=self.agent_id) |
| 89 | total_tokens = TokenCounter.count_message_tokens(messages, self.llm.get_model()) |
| 90 | token_limit = TokenCounter(session=self.toolkit_config.session, organisation_id=organisation.id).token_limit(self.llm.get_model()) |
| 91 | |
| 92 | result = self.llm.chat_completion(messages, max_tokens=(token_limit - total_tokens - 100)) |
| 93 | |
| 94 | if 'error' in result and result['message'] is not None: |
| 95 | ErrorHandler.handle_openai_errors(self.toolkit_config.session, self.agent_id, self.agent_execution_id, result['message']) |
| 96 | |
| 97 | regex = r"(\S+?)\n```\S*\n(.+?)```" |
| 98 | matches = re.finditer(regex, result["content"], re.DOTALL) |
| 99 | |
| 100 | file_names = [] |
| 101 | # Save each file |
| 102 | |
| 103 | for match in matches: |
| 104 | # Get the filename |
| 105 | file_name = re.sub(r'[<>"|?*]', "", match.group(1)) |
| 106 | code = match.group(2) |
| 107 | if not file_name.strip(): |
| 108 | continue |
| 109 | |
| 110 | file_names.append(file_name) |
| 111 | save_result = self.resource_manager.write_file(file_name, code) |
| 112 | if save_result.startswith("Error"): |
| 113 | return save_result |
| 114 | |
| 115 | # Save the tests to a file |
| 116 | # save_result = self.resource_manager.write_file(test_file_name, code_content) |
| 117 | if not result["content"].startswith("Error"): |
no test coverage detected