(path, dst_path=None, interact=True, overwrite=False)
| 309 | writer.add_tensor(new_name, data, raw_dtype=data_qtype) |
| 310 | |
def convert_file(path, dst_path=None, interact=True, overwrite=False):
    """Convert the checkpoint at *path* into a GGUF file.

    Args:
        path: Input checkpoint; handed to ``load_state_dict``.
        dst_path: Output file. When ``None`` it is derived from *path* by
            replacing the extension with ``-<ftype>.gguf``. A literal
            ``{ftype}`` placeholder in a given path is substituted with the
            selected type name ("BF16" or "F16").
        interact: If the output already exists and *overwrite* is false,
            wait for the user on stdin instead of raising.
        overwrite: Replace an existing output without prompting or raising.

    Returns:
        A ``(dst_path, model_arch)`` tuple: the resolved output path and the
        object returned by ``detect_arch``.

    Raises:
        OSError: The output exists, *overwrite* is false and *interact* is
            false.
    """
    from collections import Counter

    # load & run model detection logic
    state_dict = load_state_dict(path)
    model_arch = detect_arch(state_dict)
    logging.info(f"* Architecture detected from input: {model_arch.arch}")

    # detect & set dtype for output file: pick the dtype used by the most
    # tensors (single pass; ties resolve to the dtype seen first)
    dtype_counts = Counter(tensor.dtype for tensor in state_dict.values())
    main_dtype = max(dtype_counts, key=dtype_counts.get)

    if main_dtype == torch.bfloat16:
        ftype_name = "BF16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_BF16
    else:
        # Every non-BF16 input (including float32) is labelled F16; there is
        # no dedicated F32 output type here.
        ftype_name = "F16"
        ftype_gguf = gguf.LlamaFileType.MOSTLY_F16

    if dst_path is None:
        dst_path = f"{os.path.splitext(path)[0]}-{ftype_name}.gguf"
    elif "{ftype}" in dst_path:  # lcpp logic
        dst_path = dst_path.replace("{ftype}", ftype_name)

    if os.path.isfile(dst_path) and not overwrite:
        if interact:
            input("Output exists enter to continue or ctrl+c to abort!")
        else:
            raise OSError("Output exists and overwriting is disabled!")

    # handle actual file
    writer = gguf.GGUFWriter(path=None, arch=model_arch.arch)
    writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    if ftype_gguf is not None:
        writer.add_file_type(ftype_gguf)

    handle_tensors(writer, state_dict, model_arch)
    try:
        writer.write_header_to_file(path=dst_path)
        writer.write_kv_data_to_file()
        writer.write_tensors_to_file(progress=True)
    finally:
        # Close the writer even when a write step raises, so the output
        # handle is not left open.
        writer.close()

    fix = f"./fix_5d_tensors_{model_arch.arch}.safetensors"
    if os.path.isfile(fix):
        logging.warning(f"\n### Warning! Fix file found at '{fix}'")
        logging.warning(" you most likely need to run 'fix_5d_tensors.py' after quantization.")

    return dst_path, model_arch
| 361 | |
| 362 | if __name__ == "__main__": |
| 363 | args = parse_args() |
no test coverage detected