Retrieves and instantiates a PyTorch generative model. Takes a model_name as input, which is assumed to map to the Hugging Face repository name. This name is not necessarily the same as the LLMWare model card name, which is used to look up the model in model_configs; the model_name used here should be the hf_repo attribute on the model card.
(self, model_name, **kwargs)
| 12050 | self.custom_loader = custom_loader |
| 12051 | |
| 12052 | def get_generative_model(self, model_name, **kwargs): |
| 12053 | |
| 12054 | """ Retrieves and instantiates a Pytorch Generative model. Takes a model_name as input, which is |
| 12055 | assumed to map to the Huggingface repository name - this name is not necessarily the same as the |
| 12056 | LLMWare model card, which is used to lookup the model in model_configs -> the model_name used here |
| 12057 | should be the hf_repo attribute on the model card. """ |
| 12058 | |
| 12059 | # will return None if no model found |
| 12060 | model = None |
| 12061 | |
| 12062 | self.model_name=model_name |
| 12063 | |
| 12064 | if self.custom_loader: |
| 12065 | model = self.custom_loader.loader(self.model_name, |
| 12066 | self.api_key,self.trust_remote_code,caller="generative_model",**kwargs) |
| 12067 | |
| 12068 | else: |
| 12069 | |
| 12070 | try: |
| 12071 | # will wrap in Exception if import fails |
| 12072 | from transformers import AutoModelForCausalLM, AutoTokenizer |
| 12073 | except ImportError: |
| 12074 | raise DependencyNotInstalledException("transformers") |
| 12075 | |
| 12076 | # insert dynamic pytorch load here |
| 12077 | global GLOBAL_TORCH_IMPORT |
| 12078 | if not GLOBAL_TORCH_IMPORT: |
| 12079 | |
| 12080 | logger.debug("Pytorch loader - local dynamic load of torch here") |
| 12081 | if util.find_spec("torch"): |
| 12082 | |
| 12083 | try: |
| 12084 | global torch |
| 12085 | torch = importlib.import_module("torch") |
| 12086 | GLOBAL_TORCH_IMPORT = True |
| 12087 | except: |
| 12088 | raise LLMWareException(message="Exception: could not load torch module.") |
| 12089 | |
| 12090 | else: |
| 12091 | raise LLMWareException(message="Exception: need to import torch to use this class.") |
| 12092 | |
| 12093 | if self.api_key: |
| 12094 | |
| 12095 | if torch.cuda.is_available(): |
| 12096 | model = AutoModelForCausalLM.from_pretrained(model_name, token=self.api_key, |
| 12097 | trust_remote_code=self.trust_remote_code, |
| 12098 | torch_dtype="auto") |
| 12099 | else: |
| 12100 | model = AutoModelForCausalLM.from_pretrained(model_name, token=self.api_key, |
| 12101 | trust_remote_code=self.trust_remote_code) |
| 12102 | |
| 12103 | else: |
| 12104 | if torch.cuda.is_available(): |
| 12105 | model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=self.trust_remote_code, |
| 12106 | torch_dtype="auto") |
| 12107 | else: |
| 12108 | model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=self.trust_remote_code) |
| 12109 |
no test coverage detected