github.com/llmware-ai/llmware

Method get_generative_model

llmware/models.py:12052–12110

Retrieves and instantiates a PyTorch generative model. Takes a model_name as input, which is assumed to map to the Hugging Face repository name. This name is not necessarily the same as the LLMWare model card name, which is used to look up the model in model_configs; the model_name used here should be the hf_repo attribute on the model card.
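The docstring separates two names: the LLMWare model card name used to look up an entry in `model_configs`, and the Hugging Face repository name that this method actually receives. A minimal sketch of that translation step follows. It assumes, for illustration only, that `model_configs` behaves like a list of dicts keyed by `model_name` and `hf_repo`; the helper `resolve_hf_repo` and the example names are hypothetical, and only the `hf_repo` attribute comes from the source.

```python
from typing import Optional

# Hypothetical stand-in for model_configs; the real llmware structure may differ.
# Each entry is assumed to be a model card with "model_name" and "hf_repo" keys.
MODEL_CONFIGS = [
    {"model_name": "example-card", "hf_repo": "example-org/example-model"},
    {"model_name": "local-only-card"},  # card without an hf_repo entry
]


def resolve_hf_repo(card_name: str, model_configs: list) -> Optional[str]:
    """Return the hf_repo for a model card name, or None if unavailable.

    Hypothetical helper: get_generative_model expects the returned value,
    not the card name, as its model_name argument.
    """
    for card in model_configs:
        if card.get("model_name") == card_name:
            return card.get("hf_repo")
    return None
```

Under these assumptions, `resolve_hf_repo("example-card", MODEL_CONFIGS)` yields `"example-org/example-model"`, which is the value that would be passed as `model_name`.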

(self, model_name, **kwargs)

Source

        self.custom_loader = custom_loader

    def get_generative_model(self, model_name, **kwargs):

        """ Retrieves and instantiates a Pytorch Generative model. Takes a model_name as input, which is
        assumed to map to the Huggingface repository name - this name is not necessarily the same as the
        LLMWare model card, which is used to lookup the model in model_configs -> the model_name used here
        should be the hf_repo attribute on the model card. """

        # will return None if no model found
        model = None

        self.model_name=model_name

        if self.custom_loader:
            model = self.custom_loader.loader(self.model_name,
                                              self.api_key,self.trust_remote_code,caller="generative_model",**kwargs)

        else:

            try:
                # will wrap in Exception if import fails
                from transformers import AutoModelForCausalLM, AutoTokenizer
            except ImportError:
                raise DependencyNotInstalledException("transformers")

            # insert dynamic pytorch load here
            global GLOBAL_TORCH_IMPORT
            if not GLOBAL_TORCH_IMPORT:

                logger.debug("Pytorch loader - local dynamic load of torch here")
                if util.find_spec("torch"):

                    try:
                        global torch
                        torch = importlib.import_module("torch")
                        GLOBAL_TORCH_IMPORT = True
                    except:
                        raise LLMWareException(message="Exception: could not load torch module.")

                else:
                    raise LLMWareException(message="Exception: need to import torch to use this class.")

            if self.api_key:

                if torch.cuda.is_available():
                    model = AutoModelForCausalLM.from_pretrained(model_name, token=self.api_key,
                                                                 trust_remote_code=self.trust_remote_code,
                                                                 torch_dtype="auto")
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name, token=self.api_key,
                                                                 trust_remote_code=self.trust_remote_code)

            else:
                if torch.cuda.is_available():
                    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=self.trust_remote_code,
                                                                 torch_dtype="auto")
                else:
                    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=self.trust_remote_code)
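The four `from_pretrained` calls in the method above differ only in whether `token` and `torch_dtype="auto"` are passed. As a sketch of an equivalent, flatter structure (not the library's own code), the branch logic can be expressed as a keyword-argument builder; the function name `build_from_pretrained_kwargs` is hypothetical.

```python
from typing import Any, Dict, Optional


def build_from_pretrained_kwargs(api_key: Optional[str],
                                 trust_remote_code: bool,
                                 cuda_available: bool) -> Dict[str, Any]:
    """Collapse the four from_pretrained branches into one keyword dict.

    Hypothetical helper mirroring the branch logic of get_generative_model:
    - a token is passed only when the API key is truthy;
    - torch_dtype="auto" is requested only when CUDA is available.
    """
    kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
    if api_key:
        kwargs["token"] = api_key
    if cuda_available:
        kwargs["torch_dtype"] = "auto"
    return kwargs
```

Inside the method this would read `AutoModelForCausalLM.from_pretrained(model_name, **build_from_pretrained_kwargs(self.api_key, self.trust_remote_code, torch.cuda.is_available()))`. Building the dict once keeps the token and CUDA conditions independent, so adding another option does not double the number of branches.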

Callers (2)

load_model (method, 0.95)
__init__ (method, 0.95)

Calls (3)

LLMWareException (class, 0.90)
loader (method, 0.80)
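The `loader` call comes from the custom-loader branch: when `self.custom_loader` is set, the method skips transformers entirely and calls that object's `loader` method with the model name, API key, and `trust_remote_code` as positional arguments, plus `caller="generative_model"` and any extra keyword arguments. The sketch below shows a hypothetical duck-typed loader that satisfies that call shape. The class name `RecordingLoader`, its placeholder return value, and the extra keyword `some_option` are illustrative only; llmware may define a formal base class for custom loaders that is not shown on this page.

```python
from typing import Any, Dict, List, Optional


class RecordingLoader:
    """Hypothetical custom loader matching the call made in get_generative_model.

    It records each request and returns a placeholder instead of a real model.
    """

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def loader(self, model_name: str, api_key: Optional[str],
               trust_remote_code: bool, caller: str = "", **kwargs: Any) -> Any:
        self.calls.append({"model_name": model_name, "api_key": api_key,
                           "trust_remote_code": trust_remote_code,
                           "caller": caller, "kwargs": kwargs})
        # A real loader would construct and return a model object here.
        return {"placeholder_model_for": model_name}


# Reproduce the call shape used in the custom_loader branch.
custom_loader = RecordingLoader()
model = custom_loader.loader("example-org/example-model", None, False,
                             caller="generative_model", some_option="value")
```

Recording the arguments makes it easy to check, for instance in a unit test, that `caller` and any forwarded keyword arguments arrive as expected.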

Tested by

no test coverage detected