MCPcopy
hub / github.com/QData/TextAttack / _create_model_from_args

Method _create_model_from_args

textattack/model_args.py:170–310  ·  view source on GitHub ↗

Given ``ModelArgs``, return specified ``textattack.models.wrappers.ModelWrapper`` object.

(cls, args)

Source from the content-addressed store, hash-verified

168
169 @classmethod
170 def _create_model_from_args(cls, args):
171 """Given ``ModelArgs``, return specified
172 ``textattack.models.wrappers.ModelWrapper`` object."""
173
174 assert isinstance(
175 args, cls
176 ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
177
178 if args.model_from_file:
179 # Support loading the model from a .py file where a model wrapper
180 # is instantiated.
181 colored_model_name = textattack.shared.utils.color_text(
182 args.model_from_file, color="blue", method="ansi"
183 )
184 textattack.shared.logger.info(
185 f"Loading model and tokenizer from file: {colored_model_name}"
186 )
187 if ARGS_SPLIT_TOKEN in args.model_from_file:
188 model_file, model_name = args.model_from_file.split(ARGS_SPLIT_TOKEN)
189 else:
190 _, model_name = args.model_from_file, "model"
191 try:
192 model_module = load_module_from_file(args.model_from_file)
193 except Exception:
194 raise ValueError(f"Failed to import file {args.model_from_file}.")
195 try:
196 model = getattr(model_module, model_name)
197 except AttributeError:
198 raise AttributeError(
199 f"Variable `{model_name}` not found in module {args.model_from_file}."
200 )
201
202 if not isinstance(model, textattack.models.wrappers.ModelWrapper):
203 raise TypeError(
204 f"Variable `{model_name}` must be of type "
205 f"``textattack.models.ModelWrapper``, got type {type(model)}."
206 )
207 elif (args.model in HUGGINGFACE_MODELS) or args.model_from_huggingface:
208 # Support loading models automatically from the HuggingFace model hub.
209
210 model_name = (
211 HUGGINGFACE_MODELS[args.model]
212 if (args.model in HUGGINGFACE_MODELS)
213 else args.model_from_huggingface
214 )
215 colored_model_name = textattack.shared.utils.color_text(
216 model_name, color="blue", method="ansi"
217 )
218 textattack.shared.logger.info(
219 f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}"
220 )
221 model = transformers.AutoModelForSequenceClassification.from_pretrained(
222 model_name
223 )
224 tokenizer = transformers.AutoTokenizer.from_pretrained(
225 model_name, use_fast=True
226 )
227 model = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

Callers 4

run (Method) · 0.45
run (Method) · 0.45
run (Method) · 0.45
test_model_on_dataset (Method) · 0.45

Calls 3

load_module_from_file (Function) · 0.90
load (Method) · 0.80
from_pretrained (Method) · 0.45

Tested by 1

test_model_on_dataset (Method) · 0.36