| 107 | |
| 108 | |
def refit(engine_dir: str, checkpoint_dir: str, engine_config: EngineConfig,
          output_dir: str, fixed_weights_names: list):
    """Refit every ``*.engine`` file found directly inside ``engine_dir``.

    ``output_dir`` is created if it does not exist, and ``config.json`` from
    ``engine_dir`` is copied into it. ``refit_engine`` is then called once per
    matching engine file, with all remaining arguments passed through unchanged.

    Args:
        engine_dir: Directory holding ``config.json`` and the ``*.engine`` files
            (the glob is non-recursive).
        checkpoint_dir: Forwarded unchanged to ``refit_engine``.
        engine_config: Forwarded unchanged to ``refit_engine``.
        output_dir: Receives the copied ``config.json``. It is also forwarded
            to ``refit_engine``, presumably as the destination for refitted
            engines (TODO confirm).
        fixed_weights_names: Forwarded unchanged to ``refit_engine``.
    """
    source_dir = Path(engine_dir)
    config_name = 'config.json'

    os.makedirs(output_dir, exist_ok=True)
    src_config = os.path.join(engine_dir, config_name)
    dst_config = os.path.join(output_dir, config_name)
    shutil.copyfile(src_config, dst_config)

    # Snapshot the matches before refitting. Without the snapshot, engines
    # written into output_dir could be picked up by the glob mid-iteration
    # whenever output_dir coincides with engine_dir.
    engine_files = list(source_dir.glob('*.engine'))
    for engine_file in engine_files:
        refit_engine(engine_file, output_dir, checkpoint_dir, engine_config,
                     fixed_weights_names)
| 119 | |
| 120 | |
| 121 | def main(): |