(self, model_name, model_hash, profile, static_shape)
| 49 | return onnx_filename, onnx_path |
| 50 | |
def get_trt_path(self, model_name, model_hash, profile, static_shape):
    """Build the TensorRT engine filename and its full path for a model/profile.

    The filename is ``<model_name>_<model_hash>_<self.cc>_<profile>.trt``.
    The profile part encodes each ``profile`` entry ``name -> shapes`` as
    ``name=<s0>+<s1>+...``, where every shape is its dims joined with "x".
    Entries are joined with "-" in the mapping's iteration order.

    Args:
        model_name: Model name used as the filename prefix.
        model_hash: Hash string identifying the model.
        profile: Mapping of input name to an indexable sequence of shapes.
            Only the first shape is encoded when ``static_shape`` is truthy.
            Otherwise the first three are encoded (presumably min/opt/max;
            TODO confirm with the caller).
        static_shape: Whether the engine is built for a single static shape.

    Returns:
        Tuple ``(trt_filename, trt_path)``, where ``trt_path`` is
        ``trt_filename`` joined onto ``TRT_MODEL_DIR``.
    """
    shape_count = 3
    if static_shape:
        shape_count = 1

    def _encode_shape(shape):
        # e.g. (1, 4, 64, 64) -> "1x4x64x64"
        return "x".join(map(str, shape))

    encoded_inputs = []
    for input_name, shapes in profile.items():
        dims = "+".join(_encode_shape(shapes[idx]) for idx in range(shape_count))
        encoded_inputs.append(input_name + "=" + dims)

    name_parts = [model_name, model_hash, self.cc, "-".join(encoded_inputs)]
    trt_filename = "_".join(name_parts) + ".trt"
    return trt_filename, os.path.join(TRT_MODEL_DIR, trt_filename)
| 67 | |
def get_weights_map_path(self, model_name: str):
    """Return the path of the JSON weights map for ``model_name``.

    The file is ``<model_name>_weights_map.json`` inside ``TRT_MODEL_DIR``.
    """
    weights_map_filename = "{}_weights_map.json".format(model_name)
    return os.path.join(TRT_MODEL_DIR, weights_map_filename)
no outgoing calls
no test coverage detected