(self)
| 69 | return os.path.join(TRT_MODEL_DIR, f"{model_name}_weights_map.json") |
| 70 | |
def update(self):
    """Prune model configs whose ``.trt`` engine file no longer exists, then persist.

    Lists ``TRT_MODEL_DIR`` and collects every file name ending in ``.trt``.
    For each top-level key ``cc`` and each base model under it in
    ``self.all_models``:

    * configs whose ``"filepath"`` is not among those files are logged via
      ``info`` and dropped;
    * configs sharing the same ``"filepath"`` collapse to a single entry
      (the last occurrence wins);
    * a base model left with no configs is removed entirely.

    Finally calls ``self.write_json()``.
    """
    # A set gives O(1) membership tests inside the nested loops below.
    trt_engines = {
        trt_file
        for trt_file in os.listdir(TRT_MODEL_DIR)
        if trt_file.endswith(".trt")
    }

    for cc, base_models in self.all_models.items():
        # Iterate over a snapshot: entries of `base_models` are replaced or
        # popped inside the loop, which is not allowed on a live dict view.
        for base_model, models in list(base_models.items()):
            kept = {}
            for model_config in models:
                filepath = model_config["filepath"]
                if filepath not in trt_engines:
                    info(
                        "Model config outdated. {} was not found".format(
                            filepath
                        )
                    )
                    continue
                # Keyed by filepath so duplicate entries collapse to one.
                kept[filepath] = model_config

            if kept:
                # Store the *filtered* list; previously the unfiltered
                # `models` was written back, so stale configs never went away.
                base_models[base_model] = list(kept.values())
            else:
                base_models.pop(base_model)

    self.write_json()
| 99 | |
def __del__(self):
    """Prune and persist model configs when the instance is finalized.

    Delegates to ``self.update()``. If ``__init__`` failed before
    ``self.all_models`` was assigned, the finalizer still runs; in that case
    there is nothing to prune, so it returns without calling ``update()``
    (which would otherwise fail with an AttributeError).

    NOTE(review): Python does not guarantee when, or whether, ``__del__``
    runs (e.g. at interpreter shutdown module globals may already be torn
    down). Exceptions raised here are printed and ignored, not propagated.
    Callers that need the JSON persisted should call ``update()`` explicitly.
    """
    # Partially-initialized instance: skip rather than raise from a finalizer.
    if not hasattr(self, "all_models"):
        return
    self.update()
no test coverage detected