| 31 | |
| 32 | |
| 33 | class H4ArgumentParser(HfArgumentParser): |
def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
    """
    Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.

    Args:
        yaml_arg (`str`):
            The path to the config file used
        other_args (`List[str]`, *optional*):
            A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
            Every entry must contain `=`; only the first `=` separates key from value, so values may
            themselves contain `=`. Keys that match no field of any dataclass are ignored.

    Returns:
        [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line

    Raises:
        ValueError: if an entry of `other_args` has no `=`, if the same key is given more than once, if a key
            matches a field in more than one dataclass, or if a boolean field receives an unrecognised value.
            Invalid int/float values also raise `ValueError` from the cast itself.
    """
    import typing  # local import: `get_args` / `get_origin` may not be imported at module level

    arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))

    # Split the `--key=value` strings into {key: value}. `partition` splits on the first `=` only.
    # A plain dict comprehension would silently keep only the last of repeated keys, so repeats
    # are rejected explicitly here.
    overrides = {}
    for raw_arg in other_args or []:
        key, sep, value = raw_arg.partition("=")
        if not sep:
            raise ValueError(f"Expected command line argument of the form --key=value, got: {raw_arg!r}")
        key = key.strip("-")
        if key in overrides:
            raise ValueError(f"Duplicate argument provided: {key}, may cause unexpected behavior")
        overrides[key] = value

    # Accepted spellings for boolean fields (compared case-insensitively).
    true_strings = {"true", "t", "yes", "y", "1"}
    false_strings = {"false", "f", "no", "n", "0"}

    outputs = []
    used_args = set()  # keys already applied to an earlier dataclass

    # overwrite the default/loaded value with the value provided to the command line
    # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
    for data_yaml, data_class in zip(arg_list, self.dataclass_types):
        field_types = {f.name: f.type for f in dataclasses.fields(data_yaml) if f.init}
        inputs = {k: v for k, v in vars(data_yaml).items() if k in field_types}
        for arg, val in overrides.items():
            # add only if in keys
            if arg not in field_types:
                continue

            # A key shared by several dataclasses would be applied to each of them; refuse that.
            if arg in used_args:
                raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
            used_args.add(arg)

            # Unwrap `Optional[X]` / `X | None` so optional int/float/bool/list fields get cast too.
            base_type = field_types[arg]
            is_optional = False
            type_args = typing.get_args(base_type)
            if type(None) in type_args:
                non_none = [t for t in type_args if t is not type(None)]
                if len(non_none) == 1:
                    base_type, is_optional = non_none[0], True

            if is_optional and val == "None":
                inputs[arg] = None
            elif base_type is bool:
                # bool of a non-empty string is True, so parse the text explicitly
                lowered = val.lower()
                if lowered in true_strings:
                    inputs[arg] = True
                elif lowered in false_strings:
                    inputs[arg] = False
                else:
                    raise ValueError(f"Invalid boolean value for {arg}: {val!r}")
            elif base_type in (int, float):
                inputs[arg] = base_type(val)
            elif typing.get_origin(base_type) is list:
                # Comma-separated list; a bare `List` is treated as a list of strings.
                elem_type = (typing.get_args(base_type) or (str,))[0]
                if elem_type in (str, int, float):
                    inputs[arg] = [elem_type(v) for v in val.split(",")]
                else:
                    inputs[arg] = val
            else:
                # Any other annotation (str, Literal, enums, ...) receives the raw string.
                inputs[arg] = val

        outputs.append(data_class(**inputs))

    return outputs
| 90 | |