(self, model=None, tokenizer=None, model_name=None, api_key=None, model_card=None,
prompt_wrapper=None, instruction_following=False, context_window=2048,
use_gpu_if_available=True, trust_remote_code=True, sample=True,max_output=100, temperature=0.3,
get_logits=False, api_endpoint=None, **kwargs)
| 8752 | # 2. passing a model_name only, which will then create the model and tokenizer |
| 8753 | |
| 8754 | def __init__(self, model=None, tokenizer=None, model_name=None, api_key=None, model_card=None, |
| 8755 | prompt_wrapper=None, instruction_following=False, context_window=2048, |
| 8756 | use_gpu_if_available=True, trust_remote_code=True, sample=True,max_output=100, temperature=0.3, |
| 8757 | get_logits=False, api_endpoint=None, **kwargs): |
| 8758 | |
| 8759 | super().__init__(**kwargs) |
| 8760 | |
| 8761 | self.model_class = "HFGenerativeModel" |
| 8762 | self.model_category = "generative" |
| 8763 | self.llm_response = None |
| 8764 | self.usage = None |
| 8765 | self.logits = None |
| 8766 | self.output_tokens = None |
| 8767 | self.final_prompt = None |
| 8768 | |
| 8769 | # pull in expected hf input |
| 8770 | self.model_name = model_name |
| 8771 | self.hf_tokenizer_name = model_name |
| 8772 | self.model = model |
| 8773 | self.tokenizer = tokenizer |
| 8774 | |
| 8775 | # new parameters |
| 8776 | self.sample=sample |
| 8777 | self.get_logits=get_logits |
| 8778 | self.auto_remediate_function_call_output = True |
| 8779 | |
| 8780 | # Function Call parameters |
| 8781 | self.model_card = model_card |
| 8782 | self.logits_record = [] |
| 8783 | self.output_tokens = [] |
| 8784 | self.top_logit_count = 10 |
| 8785 | self.primary_keys = None |
| 8786 | self.function = None |
| 8787 | self.fc_supported = False |
| 8788 | |
| 8789 | if model_card: |
| 8790 | |
| 8791 | if "primary_keys" in model_card: |
| 8792 | self.primary_keys = model_card["primary_keys"] |
| 8793 | |
| 8794 | if "function" in model_card: |
| 8795 | self.function = model_card["function"] |
| 8796 | |
| 8797 | if "function_call" in model_card: |
| 8798 | self.fc_supported = model_card["function_call"] |
| 8799 | |
| 8800 | # insert dynamic pytorch load here |
| 8801 | if not api_endpoint: |
| 8802 | |
| 8803 | global GLOBAL_TORCH_IMPORT |
| 8804 | if not GLOBAL_TORCH_IMPORT: |
| 8805 | if util.find_spec("torch"): |
| 8806 | |
| 8807 | try: |
| 8808 | global torch |
| 8809 | torch = importlib.import_module("torch") |
| 8810 | GLOBAL_TORCH_IMPORT = True |
| 8811 | except: |
nothing calls this directly
no test coverage detected