MCPcopy
hub / github.com/hustvl/Vim / losses

Method losses

det/detectron2/modeling/meta_arch/fcos.py:193–238  ·  view source on GitHub ↗

This method is almost identical to :meth:`RetinaNet.losses`, with an extra centerness loss ("loss_fcos_ctr") in the returned dict.

(
        self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
    )

Source from the content-addressed store, hash-verified

191 return gt_labels, matched_gt_boxes
192
def losses(
    self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness
):
    """
    Compute the FCOS training losses.

    This method is almost identical to :meth:`RetinaNet.losses`, with an extra
    centerness loss returned under the "loss_fcos_ctr" key.

    Args:
        anchors (List[Boxes]): per-level anchors; forwarded to
            ``_dense_box_regression_loss`` and ``compute_ctrness_targets``.
        pred_logits (List[Tensor]): per-level classification logits, concatenated
            along dim 1 before the focal loss.
        gt_labels (List[Tensor]): one label tensor per image, stacked into an
            (M, R) tensor (presumably M images x R anchors). A label equal to
            ``self.num_classes`` denotes background.
        pred_anchor_deltas (List[Tensor]): per-level box regression predictions.
        gt_boxes (List[Tensor]): per-image ground-truth boxes, presumably the
            matched boxes produced by ``label_anchors`` -- confirm against caller.
        pred_centerness (List[Tensor]): per-level centerness logits, concatenated
            along dim 1 and squeezed on dim 2 into an (M, R) tensor.

    Returns:
        dict[str, Tensor]: "loss_fcos_cls", "loss_fcos_loc" and "loss_fcos_ctr".
        Each value is a sum-reduced loss divided by the same normalizer, which is
        derived from the number of positive anchors via ``self._ema_update``.
    """
    num_images = len(gt_labels)
    gt_labels = torch.stack(gt_labels)  # (M, R)

    # Positive anchors: a valid (non-negative) label that is not background.
    pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
    num_pos_anchors = pos_mask.sum().item()
    get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images)
    # max(..., 1) keeps the normalizer positive when a batch has no positives.
    # NOTE(review): _ema_update presumably smooths this count over iterations,
    # with 300 as its initial value -- confirm against its definition.
    normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300)

    # classification and regression loss
    # NOTE(review): F.one_hot raises on negative class values, so unlike pos_mask
    # this line assumes every label lies in [0, num_classes].
    gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[
        :, :, :-1
    ]  # no loss for the last (background) class
    loss_cls = sigmoid_focal_loss_jit(
        torch.cat(pred_logits, dim=1),
        gt_labels_target.to(pred_logits[0].dtype),
        alpha=self.focal_loss_alpha,
        gamma=self.focal_loss_gamma,
        reduction="sum",
    )

    # Box regression is restricted to positive anchors via pos_mask.
    loss_box_reg = _dense_box_regression_loss(
        anchors,
        self.box2box_transform,
        pred_anchor_deltas,
        gt_boxes,
        pos_mask,
        box_reg_loss_type="giou",
    )

    # Centerness is supervised only on positive anchors.
    ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes)  # (M, R)
    pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2)  # (M, R)
    ctrness_loss = F.binary_cross_entropy_with_logits(
        pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum"
    )
    return {
        "loss_fcos_cls": loss_cls / normalizer,
        "loss_fcos_loc": loss_box_reg / normalizer,
        "loss_fcos_ctr": ctrness_loss / normalizer,
    }
239
240 def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]):
241 anchors = Boxes.cat(anchors).tensor # Rx4

Callers 3

forward_training — Method · 0.95

Calls 7

get_event_storage — Function · 0.90
put_scalar — Method · 0.80
_ema_update — Method · 0.80
cat — Method · 0.45
to — Method · 0.45

Tested by 2