| 125 | self.logger.log(trt.Logger.WARNING, f'output "{out.name}" with shape: {out.shape} dtype: {out.dtype}') |
| 126 | |
| 127 | def build_from_api( |
| 128 | self, |
| 129 | fp16: bool = True, |
| 130 | input_shape: list | tuple = (1, 3, 640, 640), |
| 131 | iou_thres: float = 0.65, |
| 132 | conf_thres: float = 0.25, |
| 133 | topk: int = 100, |
| 134 | ) -> None: |
| 135 | if self.seg: |
| 136 | raise ValueError("the TensorRT API builder does not support segmentation models") |
| 137 | from .api import SPPF, C2f, Conv, Detect, get_depth, get_width |
| 138 | |
| 139 | with open(self.checkpoint, "rb") as f: |
| 140 | state_dict = pickle.load(f) |
| 141 | mapping = {0.25: 1024, 0.5: 1024, 0.75: 768, 1.0: 512, 1.25: 512} |
| 142 | |
| 143 | GW = state_dict["GW"] |
| 144 | GD = state_dict["GD"] |
| 145 | width_64 = get_width(64, GW) |
| 146 | width_128 = get_width(128, GW) |
| 147 | width_256 = get_width(256, GW) |
| 148 | width_512 = get_width(512, GW) |
| 149 | width_1024 = get_width(mapping[GW], GW) |
| 150 | depth_3 = get_depth(3, GD) |
| 151 | depth_6 = get_depth(6, GD) |
| 152 | strides = state_dict["strides"] |
| 153 | reg_max = state_dict["reg_max"] |
| 154 | images = self.network.add_input(name="images", dtype=trt.float32, shape=trt.Dims4(input_shape)) |
| 155 | assert images, "Add input failed" |
| 156 | |
| 157 | Conv_0 = Conv(self.network, state_dict, images, width_64, 3, 2, 1, "Conv.0") |
| 158 | Conv_1 = Conv(self.network, state_dict, Conv_0.get_output(0), width_128, 3, 2, 1, "Conv.1") |
| 159 | C2f_2 = C2f(self.network, state_dict, Conv_1.get_output(0), width_128, depth_3, True, 1, 0.5, "C2f.2") |
| 160 | Conv_3 = Conv(self.network, state_dict, C2f_2.get_output(0), width_256, 3, 2, 1, "Conv.3") |
| 161 | C2f_4 = C2f(self.network, state_dict, Conv_3.get_output(0), width_256, depth_6, True, 1, 0.5, "C2f.4") |
| 162 | Conv_5 = Conv(self.network, state_dict, C2f_4.get_output(0), width_512, 3, 2, 1, "Conv.5") |
| 163 | C2f_6 = C2f(self.network, state_dict, Conv_5.get_output(0), width_512, depth_6, True, 1, 0.5, "C2f.6") |
| 164 | Conv_7 = Conv(self.network, state_dict, C2f_6.get_output(0), width_1024, 3, 2, 1, "Conv.7") |
| 165 | C2f_8 = C2f(self.network, state_dict, Conv_7.get_output(0), width_1024, depth_3, True, 1, 0.5, "C2f.8") |
| 166 | SPPF_9 = SPPF(self.network, state_dict, C2f_8.get_output(0), width_1024, width_1024, 5, "SPPF.9") |
| 167 | Upsample_10 = self.network.add_resize(SPPF_9.get_output(0)) |
| 168 | assert Upsample_10, "Add Upsample_10 failed" |
| 169 | Upsample_10.resize_mode = trt.ResizeMode.NEAREST |
| 170 | Upsample_10.shape = Upsample_10.get_output(0).shape[:2] + C2f_6.get_output(0).shape[2:] |
| 171 | input_tensors11 = [Upsample_10.get_output(0), C2f_6.get_output(0)] |
| 172 | Cat_11 = self.network.add_concatenation(input_tensors11) |
| 173 | C2f_12 = C2f(self.network, state_dict, Cat_11.get_output(0), width_512, depth_3, False, 1, 0.5, "C2f.12") |
| 174 | Upsample13 = self.network.add_resize(C2f_12.get_output(0)) |
| 175 | assert Upsample13, "Add Upsample13 failed" |
| 176 | Upsample13.resize_mode = trt.ResizeMode.NEAREST |
| 177 | Upsample13.shape = Upsample13.get_output(0).shape[:2] + C2f_4.get_output(0).shape[2:] |
| 178 | input_tensors14 = [Upsample13.get_output(0), C2f_4.get_output(0)] |
| 179 | Cat_14 = self.network.add_concatenation(input_tensors14) |
| 180 | C2f_15 = C2f(self.network, state_dict, Cat_14.get_output(0), width_256, depth_3, False, 1, 0.5, "C2f.15") |
| 181 | Conv_16 = Conv(self.network, state_dict, C2f_15.get_output(0), width_256, 3, 2, 1, "Conv.16") |
| 182 | input_tensors17 = [Conv_16.get_output(0), C2f_12.get_output(0)] |
| 183 | Cat_17 = self.network.add_concatenation(input_tensors17) |
| 184 | C2f_18 = C2f(self.network, state_dict, Cat_17.get_output(0), width_512, depth_3, False, 1, 0.5, "C2f.18") |