Called by the `.function_call` method when an `api_endpoint` is passed in the model constructor. Rather than executing the inference locally, the inference request is sent over the API to the inference server.
(self, context="", tool_type="", model_name="", params="", prompt="",
function=None, endpoint_base=None, api_key=None, get_logits=False)
| 9750 | return output |
| 9751 | |
| 9752 | def function_call_over_api_endpoint(self, context="", tool_type="", model_name="", params="", prompt="", |
| 9753 | function=None, endpoint_base=None, api_key=None, get_logits=False): |
| 9754 | |
| 9755 | """ Called by .function_call method when there is an api_endpoint passed in the model constructor. Rather |
| 9756 | than execute the inference locally, it will be sent over API to inference server. """ |
| 9757 | |
| 9758 | self.context = context |
| 9759 | self.tool_type = tool_type |
| 9760 | self.model_name = model_name |
| 9761 | |
| 9762 | # send to api agent server |
| 9763 | |
| 9764 | import ast |
| 9765 | import requests |
| 9766 | |
| 9767 | if endpoint_base: |
| 9768 | self.api_endpoint = endpoint_base |
| 9769 | |
| 9770 | if api_key: |
| 9771 | # e.g., "demo-test" |
| 9772 | self.api_key = api_key |
| 9773 | |
| 9774 | if not params: |
| 9775 | self.model_name = _ModelRegistry().get_llm_fx_mapping()[tool_type] |
| 9776 | mc = ModelCatalog().lookup_model_card(self.model_name) |
| 9777 | if "primary_keys" in mc: |
| 9778 | params = mc["primary_keys"] |
| 9779 | self.primary_keys = params |
| 9780 | |
| 9781 | if function: |
| 9782 | self.function = function |
| 9783 | |
| 9784 | self.prompt = prompt |
| 9785 | |
| 9786 | # preview before invoking rest api |
| 9787 | self.preview() |
| 9788 | |
| 9789 | url = self.api_endpoint + "{}".format("/agent") |
| 9790 | output_raw = requests.post(url, data={"model_name": self.model_name, "api_key": self.api_key, |
| 9791 | "tool_type": self.tool_type, |
| 9792 | "function": self.function, |
| 9793 | "params": self.primary_keys, "max_output": 50, |
| 9794 | "temperature": 0.0, "sample": False, "prompt": self.prompt, |
| 9795 | "context": self.context, "get_logits": True}) |
| 9796 | |
| 9797 | try: |
| 9798 | # output = ast.literal_eval(output_raw.text) |
| 9799 | output = json.loads(output_raw.text) |
| 9800 | if "logits" in output: |
| 9801 | logits = ast.literal_eval(output["logits"]) |
| 9802 | output["logits"] = logits |
| 9803 | |
| 9804 | if "output_tokens" in output: |
| 9805 | ot_int = [int(x) for x in output["output_tokens"]] |
| 9806 | output["output_tokens"] = ot_int |
| 9807 | |
| 9808 | # need to clean up logits |
| 9809 | except: |
no test coverage detected