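Get the model path from a model directory. `safetensors` and `pytorch binary` (`.bin`) files are supported, and a `.safetensors` file takes precedence over a `.bin` file. Args: model_dir: model directory. name: model file name without suffix. Returns: Full model path, or None if no matching model file is found.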
(
model_dir: Union[str, Path],
name: Optional[str] = None,
)
| 200 | |
| 201 | |
def get_model_path(
    model_dir: Union[str, Path],
    name: Optional[str] = None,
) -> Optional[str]:
    """Get model path from model directory.

    `safetensors` and PyTorch binary (`.bin`) files are supported. A
    `.safetensors` file takes precedence over a `.bin` file.

    Args:
        model_dir: model directory.
        name: model file name without suffix. When omitted, the directory
            is searched for a single `*.safetensors` file and then, if none
            is present, for a single `*.bin` file.

    Returns:
        Full model path as a string, or None if no matching file exists.

    Raises:
        AssertionError: if `name` is None and more than one file with the
            selected suffix is found in `model_dir`.
    """
    model_dir = Path(model_dir)
    # Order matters: safetensors is preferred over the pytorch binary.
    suffixes = ("safetensors", "bin")

    if name is not None:
        for suffix in suffixes:
            candidate = model_dir / f"{name}.{suffix}"
            if candidate.exists():
                return str(candidate)
        return None

    for suffix in suffixes:
        model_files = list(model_dir.glob(f"*.{suffix}"))
        if not model_files:
            continue
        if len(model_files) > 1:
            # Raised explicitly instead of via `assert` so the ambiguity
            # check is not stripped under `python -O`; the exception type
            # and message stay the same for existing callers.
            raise AssertionError(
                f"find multiple {suffix} files in {model_dir}, please specify one"
            )
        return str(model_files[0])
    return None
| 237 | |
| 238 | |
| 239 | def retrieved_layer_index_from_name(name: str) -> Optional[int]: |
no outgoing calls
no test coverage detected