MCPcopy
hub / github.com/QData/TextAttack / _create_attack_from_args

Method _create_attack_from_args

textattack/attack_args.py:700–764  ·  view source on GitHub ↗

Given ``CommandLineArgs`` and ``ModelWrapper``, return the specified ``Attack`` object.

(cls, args, model_wrapper)

Source retrieved from the content-addressed store (hash-verified)

698
699 @classmethod
700 def _create_attack_from_args(cls, args, model_wrapper):
701 """Given ``CommandLineArgs`` and ``ModelWrapper``, return specified
702 ``Attack`` object."""
703
704 assert isinstance(
705 args, cls
706 ), f"Expect args to be of type `{type(cls)}`, but got type `{type(args)}`."
707
708 if args.attack_recipe:
709 if ARGS_SPLIT_TOKEN in args.attack_recipe:
710 recipe_name, params = args.attack_recipe.split(ARGS_SPLIT_TOKEN)
711 if recipe_name not in ATTACK_RECIPE_NAMES:
712 raise ValueError(f"Error: unsupported recipe {recipe_name}")
713 recipe = eval(
714 f"{ATTACK_RECIPE_NAMES[recipe_name]}.build(model_wrapper, {params})"
715 )
716 elif args.attack_recipe in ATTACK_RECIPE_NAMES:
717 recipe = eval(
718 f"{ATTACK_RECIPE_NAMES[args.attack_recipe]}.build(model_wrapper)"
719 )
720 else:
721 raise ValueError(f"Invalid recipe {args.attack_recipe}")
722 if args.query_budget:
723 recipe.goal_function.query_budget = args.query_budget
724 recipe.goal_function.model_cache_size = args.model_cache_size
725 recipe.goal_function.batch_size = args.model_batch_size
726 recipe.constraint_cache_size = args.constraint_cache_size
727 return recipe
728 elif args.attack_from_file:
729 if ARGS_SPLIT_TOKEN in args.attack_from_file:
730 attack_file, attack_name = args.attack_from_file.split(ARGS_SPLIT_TOKEN)
731 else:
732 attack_file, attack_name = args.attack_from_file, "attack"
733 attack_module = load_module_from_file(attack_file)
734 if not hasattr(attack_module, attack_name):
735 raise ValueError(
736 f"Loaded `{attack_file}` but could not find `{attack_name}`."
737 )
738 attack_func = getattr(attack_module, attack_name)
739 return attack_func(model_wrapper)
740 else:
741 goal_function = cls._create_goal_function_from_args(args, model_wrapper)
742 transformation = cls._create_transformation_from_args(args, model_wrapper)
743 constraints = cls._create_constraints_from_args(args)
744 if ARGS_SPLIT_TOKEN in args.search_method:
745 search_name, params = args.search_method.split(ARGS_SPLIT_TOKEN)
746 if search_name not in SEARCH_METHOD_CLASS_NAMES:
747 raise ValueError(f"Error: unsupported search {search_name}")
748 search_method = eval(
749 f"{SEARCH_METHOD_CLASS_NAMES[search_name]}({params})"
750 )
751 elif args.search_method in SEARCH_METHOD_CLASS_NAMES:
752 search_method = eval(
753 f"{SEARCH_METHOD_CLASS_NAMES[args.search_method]}()"
754 )
755 else:
756 raise ValueError(f"Error: unsupported attack {args.search_method}")
757

Callers 3

runMethod · 0.45
runMethod · 0.45
runMethod · 0.45

Tested by

no test coverage detected