Called by the `.function_call` method when an `api_endpoint` was passed to the model constructor. Rather than executing the inference locally, the request is sent over the API to the inference server.
(self, context="", tool_type="", model_name="", params="", prompt="",
function=None, endpoint_base=None, api_key=None, get_logits=False)
| 4001 | return output |
| 4002 | |
| 4003 | def function_call_over_api_endpoint(self, context="", tool_type="", model_name="", params="", prompt="", |
| 4004 | function=None, endpoint_base=None, api_key=None, get_logits=False): |
| 4005 | |
| 4006 | """ Called by .function_call method when there is an api_endpoint passed in the model constructor. Rather |
| 4007 | than execute the inference locally, it will be sent over API to inference server. """ |
| 4008 | |
| 4009 | # send to api agent server |
| 4010 | |
| 4011 | import ast |
| 4012 | import requests |
| 4013 | |
| 4014 | self.context = context |
| 4015 | self.tool_type = tool_type |
| 4016 | if model_name: |
| 4017 | self.model_name = model_name |
| 4018 | |
| 4019 | self.preview() |
| 4020 | |
| 4021 | if endpoint_base: |
| 4022 | self.api_endpoint = endpoint_base |
| 4023 | |
| 4024 | if api_key: |
| 4025 | # e.g., "demo-test" |
| 4026 | self.api_key = api_key |
| 4027 | |
| 4028 | if not params: |
| 4029 | model_name = _ModelRegistry().get_llm_fx_mapping()[tool_type] |
| 4030 | mc = ModelCatalog().lookup_model_card(model_name) |
| 4031 | if "primary_keys" in mc: |
| 4032 | params = mc["primary_keys"] |
| 4033 | self.primary_keys = params |
| 4034 | |
| 4035 | if function: |
| 4036 | self.function = function |
| 4037 | |
| 4038 | self.context = context |
| 4039 | self.prompt = prompt |
| 4040 | |
| 4041 | url = self.api_endpoint + "{}".format("/agent") |
| 4042 | output_raw = requests.post(url, data={"model_name": self.model_name, "api_key": self.api_key, |
| 4043 | "tool_type": self.tool_type, |
| 4044 | "function": self.function, |
| 4045 | "params": self.primary_keys, "max_output": 50, |
| 4046 | "temperature": 0.0, "sample": False, "prompt": self.prompt, |
| 4047 | "context": self.context, "get_logits": True}) |
| 4048 | |
| 4049 | try: |
| 4050 | output = json.loads(output_raw.text) |
| 4051 | if "logits" in output: |
| 4052 | logits = ast.literal_eval(output["logits"]) |
| 4053 | output["logits"] = logits |
| 4054 | |
| 4055 | if "output_tokens" in output: |
| 4056 | ot_int = [int(x) for x in output["output_tokens"]] |
| 4057 | output["output_tokens"] = ot_int |
| 4058 | |
| 4059 | # need to clean up logits |
| 4060 |
no test coverage detected