(config: Union[str, dict])
| 17 | |
| 18 | |
def load_ds_config(config: Union[str, dict]) -> dict:
    """Normalize a DeepSpeed config given in one of several forms into a dict.

    Accepted forms:
      * a ``dict`` -- returned unchanged (the same object, not a copy);
      * a ``str`` naming an existing filesystem path -- the file is parsed
        with ``hjson.load`` using ``dict_raise_error_on_duplicate_keys`` as
        the ``object_pairs_hook`` (presumably rejects duplicate keys --
        confirm against that helper);
      * any other ``str`` -- treated as URL-safe base64 of UTF-8 text, which
        is then parsed with ``hjson.loads``.

    Args:
        config: A config dict, a path to a config file, or a URL-safe
            base64-encoded config string.

    Returns:
        The parsed configuration as a dict.

    Raises:
        ValueError: if ``config`` is neither a ``str`` nor a ``dict``; if a
            non-path string cannot be base64/UTF-8 decoded or parsed; or if
            the parsed config is not a dictionary. Errors raised by
            ``hjson.load`` while parsing an existing file propagate unchanged.
    """
    expected_msg = (f"Expected a string path to an existing deepspeed config, or a dictionary or a valid base64. "
                    f"Received: {config}")

    if isinstance(config, dict):
        return config
    if not isinstance(config, str):
        raise ValueError(expected_msg)

    if os.path.exists(config):
        # Context manager so the file handle is closed even if parsing fails.
        with open(config, "r") as config_file:
            parsed = hjson.load(config_file, object_pairs_hook=dict_raise_error_on_duplicate_keys)
    else:
        try:
            config_decoded = base64.urlsafe_b64decode(config).decode('utf-8')
            # NOTE(review): unlike the file branch, no duplicate-key hook is
            # applied here, so repeated keys in a base64 config are not
            # rejected. Kept as-is to avoid changing accepted inputs.
            parsed = hjson.loads(config_decoded)
        except (UnicodeDecodeError, AttributeError, ValueError) as exc:
            # binascii.Error (bad base64) and hjson decode errors are
            # ValueError subclasses, so they are reported here as well.
            raise ValueError(expected_msg) from exc

    # hjson happily parses scalars/lists at the root (e.g. a mistyped path
    # that is accidentally valid base64), which would violate the `dict`
    # return contract further downstream; reject it here instead.
    if not isinstance(parsed, dict):
        raise ValueError(f"Expected the deepspeed config to decode to a dictionary, "
                         f"but got {type(parsed).__name__}. Received: {config}")
    return parsed
| 34 | |
| 35 | |
| 36 | def record_tp_model_init_args(tp_size, dtype, tp_group, dist_module): |
no test coverage detected
searching dependent graphs…