Called by the .function_call method when an api_endpoint is passed in the model constructor. Rather than executing the inference locally, the request is sent over the API to the inference server.
(self, context="", tool_type="", model_name="", params="", prompt="",
function=None, endpoint_base=None, api_key=None, get_logits=False)
| 5143 | return output_response |
| 5144 | |
| 5145 | def function_call_over_api_endpoint(self, context="", tool_type="", model_name="", params="", prompt="", |
| 5146 | function=None, endpoint_base=None, api_key=None, get_logits=False): |
| 5147 | |
| 5148 | """ Called by .function_call method when there is an api_endpoint passed in the model constructor. Rather |
| 5149 | than execute the inference locally, it will be sent over API to inference server. """ |
| 5150 | |
| 5151 | # send to api agent server |
| 5152 | |
| 5153 | self.context = context |
| 5154 | self.tool_type = tool_type |
| 5155 | self.prompt = prompt |
| 5156 | |
| 5157 | import ast |
| 5158 | import requests |
| 5159 | |
| 5160 | if endpoint_base: |
| 5161 | self.api_endpoint = endpoint_base |
| 5162 | |
| 5163 | if api_key: |
| 5164 | # e.g., "demo-test" |
| 5165 | self.api_key = api_key |
| 5166 | |
| 5167 | if not params: |
| 5168 | |
| 5169 | self.model_name = _ModelRegistry().get_llm_fx_mapping()[tool_type] |
| 5170 | mc = ModelCatalog().lookup_model_card(self.model_name) |
| 5171 | if "primary_keys" in mc: |
| 5172 | params = mc["primary_keys"] |
| 5173 | self.primary_keys = params |
| 5174 | |
| 5175 | if function: |
| 5176 | self.function = function |
| 5177 | |
| 5178 | self.context = context |
| 5179 | |
| 5180 | self.preview() |
| 5181 | |
| 5182 | url = self.api_endpoint + "{}".format("/agent") |
| 5183 | output_raw = requests.post(url, data={"model_name": self.model_name, "api_key": self.api_key, |
| 5184 | "tool_type": self.tool_type, |
| 5185 | "function": self.function, "params": self.primary_keys, "max_output": 50, |
| 5186 | "temperature": 0.0, "sample": False, "prompt": self.prompt, |
| 5187 | "context": self.context, "get_logits": True}) |
| 5188 | |
| 5189 | try: |
| 5190 | # output = ast.literal_eval(output_raw.text) |
| 5191 | output = json.loads(output_raw.text) |
| 5192 | if "logits" in output: |
| 5193 | logits = ast.literal_eval(output["logits"]) |
| 5194 | output["logits"] = logits |
| 5195 | |
| 5196 | if "output_tokens" in output: |
| 5197 | ot_int = [int(x) for x in output["output_tokens"]] |
| 5198 | output["output_tokens"] = ot_int |
| 5199 | |
| 5200 | except: |
| 5201 | logger.warning("OVGenerativeModel - api inference was not successful") |
| 5202 | output = {} |
no test coverage detected