This method is almost identical to :meth:`RetinaNet.losses`, with an extra centerness loss (returned under the key "loss_fcos_ctr") in the returned dict.
(
self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
)
| 191 | return gt_labels, matched_gt_boxes |
| 192 | |
def losses(
    self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
):
    """
    Compute the classification, box-regression and centerness losses.

    Each loss is summed over its elements and then divided by a shared
    normalizer obtained from ``self._ema_update`` (presumably a running
    average of the positive-anchor count -- confirm against that helper).

    Args:
        anchors: forwarded unchanged to ``_dense_box_regression_loss`` and
            ``self.compute_ctrness_targets``.
        pred_logits: sequence of tensors, concatenated along dim 1 to form
            the classification logits.
        gt_labels: sequence of per-image label tensors; stacked into a
            single (M, R) tensor. The value ``self.num_classes`` marks
            background.
        pred_anchor_deltas: forwarded unchanged to
            ``_dense_box_regression_loss``.
        gt_boxes: forwarded unchanged to ``_dense_box_regression_loss`` and
            ``self.compute_ctrness_targets``.
        pred_centerness: sequence of tensors, concatenated along dim 1 and
            squeezed on dim 2 to give (M, R) centerness logits.

    Returns:
        dict: the normalized losses under the keys "loss_fcos_cls",
        "loss_fcos_loc" and "loss_fcos_ctr".
    """
    batch_size = len(gt_labels)
    gt_labels = torch.stack(gt_labels)  # (M, R)

    # An anchor is positive when its label is non-negative and is not the
    # background index (self.num_classes).
    is_positive = (gt_labels >= 0) & (gt_labels != self.num_classes)
    positive_count = is_positive.sum().item()

    # Log the per-image positive count first, then refresh the normalizer;
    # the max(..., 1) keeps the value fed to the update from being zero.
    storage = get_event_storage()
    storage.put_scalar("num_pos_anchors", positive_count / batch_size)
    loss_scale = self._ema_update("loss_normalizer", max(positive_count, 1), 300)

    # --- classification ---
    # One-hot over num_classes + 1 slots, then drop the trailing slot so the
    # background class contributes no target column.
    one_hot_labels = F.one_hot(gt_labels, num_classes=self.num_classes + 1)
    cls_targets = one_hot_labels[:, :, :-1]
    cls_logits = torch.cat(pred_logits, dim=1)
    cls_loss_sum = sigmoid_focal_loss_jit(
        cls_logits,
        cls_targets.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )

    # --- box regression (GIoU), restricted to positive anchors ---
    box_loss_sum = _dense_box_regression_loss(
        anchors,
        self.box2box_transform,
        pred_anchor_deltas,
        gt_boxes,
        is_positive,
        box_reg_loss_type="giou",
    )

    # --- centerness, evaluated on positive anchors only ---
    centerness_targets = self.compute_ctrness_targets(anchors, gt_boxes)  # (M, R)
    centerness_logits = torch.cat(pred_centerness, dim=1)
    centerness_logits = centerness_logits.squeeze(dim=2)  # (M, R)
    centerness_loss_sum = F.binary_cross_entropy_with_logits(
        centerness_logits[is_positive],
        centerness_targets[is_positive],
        reduction="sum",
    )

    summed_losses = {
        "loss_fcos_cls": cls_loss_sum,
        "loss_fcos_loc": box_loss_sum,
        "loss_fcos_ctr": centerness_loss_sum,
    }
    return {name: total / loss_scale for name, total in summed_losses.items()}
| 239 | |
| 240 | def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]): |
| 241 | anchors = Boxes.cat(anchors).tensor # Rx4 |