(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
engine_config: EngineConfig, fixed_weights_names: list)
| 23 | |
| 24 | @_is_building |
| 25 | def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str, |
| 26 | engine_config: EngineConfig, fixed_weights_names: list): |
| 27 | # This function loops through all weights in the model and does a textual match between |
| 28 | # checkpoint weight names and engine weight names. |
| 29 | rank = int(ENGINE_RE.fullmatch(os.path.basename(engine_path)).group(1)) |
| 30 | tik = time.time() |
| 31 | with open(engine_path, |
| 32 | "rb") as f, trt.Runtime(logger.trt_logger) as runtime: |
| 33 | engine = runtime.deserialize_cuda_engine(f.read()) |
| 34 | tok = time.time() |
| 35 | t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) |
| 36 | logger.info(f'Load TRT engine time: {t}') |
| 37 | |
| 38 | refitter = trt.Refitter(engine, logger.trt_logger) |
| 39 | refittable_weights = set(refitter.get_all_weights()) |
| 40 | |
| 41 | # Load model. |
| 42 | tik = time.time() |
| 43 | rank_config = PretrainedConfig.from_dict( |
| 44 | engine_config.pretrained_config.to_dict()) |
| 45 | rank_config.set_rank(rank) |
| 46 | |
| 47 | architecture = rank_config.architecture |
| 48 | assert architecture in MODEL_MAP, \ |
| 49 | f"Unsupported model architecture: {architecture}" |
| 50 | model_cls = MODEL_MAP[architecture] |
| 51 | model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config) |
| 52 | |
| 53 | tok = time.time() |
| 54 | t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) |
| 55 | logger.info(f'Load checkpoint(s) time: {t}') |
| 56 | |
| 57 | # Weights are preprocessed while the model is being optimized. |
| 58 | tik = time.time() |
| 59 | build_config = engine_config.build_config.model_copy(deep=True) |
| 60 | optimize_model_with_config(model, build_config) |
| 61 | tok = time.time() |
| 62 | t = time.strftime('%H:%M:%S', time.gmtime(tok - tik)) |
| 63 | logger.info(f'Preprocess weights time: {t}') |
| 64 | |
| 65 | # Refit engine. |
| 66 | tik = time.time() |
| 67 | refitted_weights = [] |
| 68 | for name, buf in model.named_parameters(): |
| 69 | if name in refittable_weights: |
| 70 | assert buf.is_inited, f"Failed because weight `{name}` is not initialized in model." |
| 71 | weight = buf._value |
| 72 | weights_value = trt.Weights(np_dtype_to_trt(weight.dtype), |
| 73 | weight.ctypes.data, weight.size) |
| 74 | assert refitter.set_named_weights( |
| 75 | name, weights_value), f'Failed to refit weight: `{name}`' |
| 76 | refitted_weights.append(name) |
| 77 | else: |
| 78 | if name not in fixed_weights_names: |
| 79 | logger.warning( |
| 80 | f"model weights `{name}` (shape: {buf._value.shape}) is not refittable, this means that we might not be able to update the engine using fine-tuned checkpoint!" |
| 81 | ) |
| 82 |
no test coverage detected