| 24 | |
class StarChatArgumentParser(HfArgumentParser):
    """Argument parser that layers ``--key=value`` overrides on top of a YAML config."""

    @staticmethod
    def _str_to_bool(name: str, raw: str) -> bool:
        """Convert the command-line string *raw* for field *name* into a bool.

        ``bool("False")`` is ``True`` because any non-empty string is truthy,
        so the text has to be interpreted explicitly.
        """
        truthy = ("true", "t", "yes", "y", "1")
        # The empty string maps to False, matching ``bool("")``.
        falsy = ("false", "f", "no", "n", "0", "")
        lowered = raw.strip().lower()
        if lowered in truthy:
            return True
        if lowered in falsy:
            return False
        raise ValueError(f"Cannot interpret {raw!r} as a boolean for argument: {name}")

    def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
        """Load dataclass instances from a YAML file and apply command-line overrides.

        Args:
            yaml_arg: Path to the YAML file; it is made absolute and handed to
                ``self.parse_yaml_file``.
            other_args: Extra arguments of the form ``--key=value``. ``None``
                (the default) or an empty list means "no overrides". When the
                same key is given more than once, the last value wins.

        Returns:
            One freshly constructed instance per entry of
            ``self.dataclass_types``, in the same order, with every matching
            override applied.

        Raises:
            ValueError: If an override is not of the form ``key=value``, if a
                boolean override is not a recognised true/false spelling, if an
                ``int``/``float`` override cannot be converted, or if one
                override key matches a field in more than one dataclass.
        """
        arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))

        # Turn the raw ``--key=value`` strings into a mapping of key -> value.
        overrides = {}
        for raw_arg in other_args or []:
            # partition splits on the FIRST "=" only, so values that themselves
            # contain "=" are kept intact.
            key, sep, value = raw_arg.partition("=")
            if not sep:
                raise ValueError(f"Expected an argument of the form --key=value, got: {raw_arg}")
            # Only leading dashes are option markers.
            overrides[key.lstrip("-")] = value

        outputs = []
        # Override keys already consumed by an earlier dataclass.
        used_args = set()

        # overwrite the default/loaded value with the value provided to the command line
        # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
        for data_yaml, data_class in zip(arg_list, self.dataclass_types):
            # Only fields accepted by __init__ can be passed back to the constructor.
            field_types = {f.name: f.type for f in dataclasses.fields(data_yaml) if f.init}
            inputs = {k: v for k, v in vars(data_yaml).items() if k in field_types}

            for arg, val in overrides.items():
                # Overrides that do not name a field of this dataclass are ignored here.
                if arg not in field_types:
                    continue

                # The same key matching fields in two dataclasses is ambiguous.
                if arg in used_args:
                    raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
                used_args.add(arg)

                # Cast ints, floats and bools; every other field type keeps the raw string.
                base_type = field_types[arg]
                if base_type is bool:
                    inputs[arg] = self._str_to_bool(arg, val)
                elif base_type in (int, float):
                    inputs[arg] = base_type(val)
                else:
                    inputs[arg] = val

            outputs.append(data_class(**inputs))

        return outputs
| 59 | |
| 60 | |
| 61 | def hf_login(): |