Given ``ModelArgs``, return specified ``textattack.models.wrappers.ModelWrapper`` object.
(cls, args)
| 168 | |
| 169 | @classmethod |
| 170 | def _create_model_from_args(cls, args): |
| 171 | """Given ``ModelArgs``, return specified |
| 172 | ``textattack.models.wrappers.ModelWrapper`` object.""" |
| 173 | |
| 174 | assert isinstance( |
| 175 | args, cls |
| 176 | ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`." |
| 177 | |
| 178 | if args.model_from_file: |
| 179 | # Support loading the model from a .py file where a model wrapper |
| 180 | # is instantiated. |
| 181 | colored_model_name = textattack.shared.utils.color_text( |
| 182 | args.model_from_file, color="blue", method="ansi" |
| 183 | ) |
| 184 | textattack.shared.logger.info( |
| 185 | f"Loading model and tokenizer from file: {colored_model_name}" |
| 186 | ) |
| 187 | if ARGS_SPLIT_TOKEN in args.model_from_file: |
| 188 | model_file, model_name = args.model_from_file.split(ARGS_SPLIT_TOKEN) |
| 189 | else: |
| 190 | _, model_name = args.model_from_file, "model" |
| 191 | try: |
| 192 | model_module = load_module_from_file(args.model_from_file) |
| 193 | except Exception: |
| 194 | raise ValueError(f"Failed to import file {args.model_from_file}.") |
| 195 | try: |
| 196 | model = getattr(model_module, model_name) |
| 197 | except AttributeError: |
| 198 | raise AttributeError( |
| 199 | f"Variable `{model_name}` not found in module {args.model_from_file}." |
| 200 | ) |
| 201 | |
| 202 | if not isinstance(model, textattack.models.wrappers.ModelWrapper): |
| 203 | raise TypeError( |
| 204 | f"Variable `{model_name}` must be of type " |
| 205 | f"``textattack.models.ModelWrapper``, got type {type(model)}." |
| 206 | ) |
| 207 | elif (args.model in HUGGINGFACE_MODELS) or args.model_from_huggingface: |
| 208 | # Support loading models automatically from the HuggingFace model hub. |
| 209 | |
| 210 | model_name = ( |
| 211 | HUGGINGFACE_MODELS[args.model] |
| 212 | if (args.model in HUGGINGFACE_MODELS) |
| 213 | else args.model_from_huggingface |
| 214 | ) |
| 215 | colored_model_name = textattack.shared.utils.color_text( |
| 216 | model_name, color="blue", method="ansi" |
| 217 | ) |
| 218 | textattack.shared.logger.info( |
| 219 | f"Loading pre-trained model from HuggingFace model repository: {colored_model_name}" |
| 220 | ) |
| 221 | model = transformers.AutoModelForSequenceClassification.from_pretrained( |
| 222 | model_name |
| 223 | ) |
| 224 | tokenizer = transformers.AutoTokenizer.from_pretrained( |
| 225 | model_name, use_fast=True |
| 226 | ) |
| 227 | model = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer) |