Called by the .inference method when an api_endpoint is passed in the model constructor. Rather than executing the inference locally, the request is sent over the API to an inference server.
(self, prompt, context=None, inference_dict=None, get_logits=False)
| 3940 | return True |
| 3941 | |
| 3942 | def inference_over_api_endpoint(self, prompt, context=None, inference_dict=None, get_logits=False): |
| 3943 | |
| 3944 | """ Called by .inference method when there is an api_endpoint passed in the model constructor. Rather |
| 3945 | than execute the inference locally, it will be sent over API to inference server. """ |
| 3946 | |
| 3947 | import ast |
| 3948 | import requests |
| 3949 | |
| 3950 | self.prompt = prompt |
| 3951 | self.context = context |
| 3952 | |
| 3953 | self.preview() |
| 3954 | |
| 3955 | url = self.api_endpoint + "{}".format("/") |
| 3956 | output_raw = requests.post(url, data={"model_name": self.model_name, |
| 3957 | "question": self.prompt, |
| 3958 | "context": self.context, |
| 3959 | "api_key": self.api_key, |
| 3960 | "max_output": self.max_output, |
| 3961 | "temperature": self.temperature}) |
| 3962 | |
| 3963 | try: |
| 3964 | |
| 3965 | output = json.loads(output_raw.text) |
| 3966 | |
| 3967 | # will attempt to unpack logits - but catch any exceptions and skip |
| 3968 | if "logits" in output: |
| 3969 | try: |
| 3970 | logits = ast.literal_eval(output["logits"]) |
| 3971 | output["logits"] = logits |
| 3972 | except: |
| 3973 | output["logits"] = [] |
| 3974 | |
| 3975 | # will attempt to unpack output tokens - but catch any exceptions and skip |
| 3976 | if "output_tokens" in output: |
| 3977 | try: |
| 3978 | # alt: ot_int = [int(x) for x in output["output_tokens"]] |
| 3979 | # alt: output["output_tokens"] = ot_int |
| 3980 | output_tokens = ast.literal_eval(output["output_tokens"]) |
| 3981 | output["output_tokens"] = output_tokens |
| 3982 | except: |
| 3983 | output["output_tokens"] = [] |
| 3984 | |
| 3985 | except: |
| 3986 | logger.warning("warning: api inference was not successful") |
| 3987 | output = {"llm_response": "api-inference-error", "usage": {}} |
| 3988 | |
| 3989 | # output inference parameters |
| 3990 | self.llm_response = output["llm_response"] |
| 3991 | self.usage = output["usage"] |
| 3992 | self.final_prompt = prompt |
| 3993 | |
| 3994 | if "logits" in output: |
| 3995 | self.logits = output["logits"] |
| 3996 | if "output_tokens" in output: |
| 3997 | self.output_tokens = output["output_tokens"] |
| 3998 | |
| 3999 | self.register() |