(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor])
| 238 | } |
| 239 | |
def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]):
    """Compute centerness targets of every anchor with respect to each entry of ``gt_boxes``.

    Centerness is ``sqrt(min(d0, d2) / max(d0, d2) * min(d1, d3) / max(d1, d3))``.
    Here ``d0..d3`` are the four components returned by
    ``self.box2box_transform.get_deltas(anchors, m)``; the code names them
    left/right (indices 0, 2) and top/bottom (indices 1, 3).

    Args:
        anchors: list of ``Boxes``, concatenated with ``Boxes.cat`` into an (R, 4) tensor.
        gt_boxes: one tensor per image. Each is passed to ``get_deltas`` together with
            the concatenated anchors. Presumably each is the (R, 4) ground-truth box
            matched to every anchor. NOTE(review): confirm against caller.

    Returns:
        Tensor of shape (N, R), where N = len(gt_boxes). If ``gt_boxes`` is empty,
        an empty (0, R) tensor with the anchors' dtype and device.
    """
    anchors = Boxes.cat(anchors).tensor  # Rx4
    # torch.stack raises on an empty sequence, so the "no images" case has to be
    # handled before stacking. A length check after the stack can never be reached.
    if len(gt_boxes) == 0:
        return anchors.new_zeros((0, anchors.shape[0]))
    reg_targets = torch.stack(
        [self.box2box_transform.get_deltas(anchors, m) for m in gt_boxes], dim=0
    )  # NxRx4
    left_right = reg_targets[:, :, [0, 2]]
    top_bottom = reg_targets[:, :, [1, 3]]
    # NOTE(review): entries whose delta pair has a zero max, or mixed/negative signs,
    # can yield inf/NaN here. The caller presumably keeps only positive anchors.
    # Confirm.
    ctrness = (left_right.min(dim=-1).values / left_right.max(dim=-1).values) * (
        top_bottom.min(dim=-1).values / top_bottom.max(dim=-1).values
    )
    return torch.sqrt(ctrness)
| 252 | |
| 253 | def forward_inference( |
| 254 | self, |
no test coverage detected