Get the format of the model, inferred from the `config.json` in `model_dir`.

Signature: `get_model_format(model_dir: str, trust_remote_code: bool = False) -> _ModelFormatKind`
| 3378 | |
| 3379 | |
def get_model_format(model_dir: str,
                     trust_remote_code: bool = False) -> _ModelFormatKind:
    ''' Get the format of the model.

    The format is inferred from the top-level keys of
    ``<model_dir>/config.json``:

    * ``pretrained_config`` and ``build_config`` -> ``TLLM_ENGINE``
    * ``architecture`` and ``dtype``            -> ``TLLM_CKPT``
    * anything else                            -> ``HF``

    The guess is then validated by loading the config with the loader that
    matches the inferred format; the loaded object is discarded.

    Args:
        model_dir: Directory expected to contain ``config.json``.
        trust_remote_code: Forwarded to ``AutoConfig.from_hugging_face`` when
            the format is inferred as Hugging Face.

    Returns:
        The inferred ``_ModelFormatKind``.

    Raises:
        ValueError: If ``config.json`` is missing (or is not a regular file),
            is not valid JSON, is not a JSON object, or cannot be loaded by
            the loader for the inferred format.
    '''
    config_path = Path(model_dir) / 'config.json'
    # is_file() (rather than exists()) so a directory named config.json is
    # reported here instead of failing later inside open().
    if not config_path.is_file():
        raise ValueError(
            f"Failed to infer model format because no config.json exists in {model_dir}"
        )

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    # A malformed file raises json.JSONDecodeError, a ValueError subclass.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)

    # The key checks below only make sense for a JSON object; a list would
    # otherwise silently fall through to HF and a scalar would raise TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            f"Failed to infer model format because config.json in {model_dir} is not a JSON object"
        )

    try:
        if 'pretrained_config' in config and 'build_config' in config:
            model_format = _ModelFormatKind.TLLM_ENGINE
            EngineConfig.from_json_file(config_path)
        elif 'architecture' in config and 'dtype' in config:
            model_format = _ModelFormatKind.TLLM_CKPT
            PretrainedConfig.from_checkpoint(model_dir)
        else:
            model_format = _ModelFormatKind.HF
            AutoConfig.from_hugging_face(model_dir,
                                         trust_remote_code=trust_remote_code)
    except Exception as e:
        # Any loader failure means the inferred format could not be confirmed.
        # Chain explicitly so the loader's traceback is kept as the cause.
        raise ValueError(
            f"Inferred model format {model_format}, but failed to load config.json: {e}"
        ) from e
    return model_format
| 3408 | |
| 3409 | |
# Public alias: ``LlmArgs`` is bound to the very same class object as
# ``TorchLlmArgs``. NOTE(review): presumably kept so code importing
# ``LlmArgs`` gets the PyTorch-backend argument class by default — confirm.
LlmArgs = TorchLlmArgs