This is the key inference method for SLIM models: it takes a context passage and a list of keys, which are packaged into the prompt as the keys of the dictionary output.
(self, context, function=None, params=None, get_logits=True,
temperature=-99, max_output=None)
| 9396 | return top_logits |
| 9397 | |
| 9398 | def function_call(self, context, function=None, params=None, get_logits=True, |
| 9399 | temperature=-99, max_output=None): |
| 9400 | |
| 9401 | """ This is the key inference method for SLIM models - takes a context passage and a key list |
| 9402 | which is packaged in the prompt as the keys for the dictionary output""" |
| 9403 | |
| 9404 | self.context = context |
| 9405 | |
| 9406 | # only assign self.function if a function has been passed in the call |
| 9407 | if function: |
| 9408 | self.function = function |
| 9409 | |
| 9410 | if not self.fc_supported: |
| 9411 | logger.warning("HFGenerativeModel - loaded model does not support function calls. " |
| 9412 | "Please either use the standard .inference method with this model, or use a " |
| 9413 | "model that has 'function_calls' key set to True in its model card.") |
| 9414 | return [] |
| 9415 | |
| 9416 | # reset and start from scratch with new function call |
| 9417 | self.output_tokens = [] |
| 9418 | self.logits_record = [] |
| 9419 | |
| 9420 | if temperature != -99: |
| 9421 | self.temperature = temperature |
| 9422 | |
| 9423 | if max_output: |
| 9424 | self.target_requested_output_tokens = max_output |
| 9425 | |
| 9426 | if get_logits: |
| 9427 | self.get_logits = get_logits |
| 9428 | |
| 9429 | if params: |
| 9430 | self.primary_keys = params |
| 9431 | |
| 9432 | # call to preview (not implemented by default) |
| 9433 | self.preview() |
| 9434 | |
| 9435 | if not self.primary_keys: |
| 9436 | logger.warning("warning: function call - no keys provided - function call may yield unpredictable results") |
| 9437 | |
| 9438 | # START - route to api endpoint |
| 9439 | |
| 9440 | if self.api_endpoint: |
| 9441 | return self.function_call_over_api_endpoint(model_name=self.model_name, |
| 9442 | context=self.context,params=self.primary_keys, |
| 9443 | function=self.function, |
| 9444 | api_key=self.api_key,get_logits=self.get_logits) |
| 9445 | |
| 9446 | # END - route to api endpoint |
| 9447 | |
| 9448 | prompt = self.fc_prompt_engineer(self.context, params=self.primary_keys, function=self.function) |
| 9449 | |
| 9450 | # second - tokenize to get the input_ids |
| 9451 | |
| 9452 | tokenizer_output = self.tokenizer.encode(prompt) |
| 9453 | input_token_len = len(tokenizer_output) |
| 9454 | input_ids = torch.tensor(tokenizer_output).unsqueeze(0) |
| 9455 |
Nothing calls this method directly.
No test coverage was detected.