(self, num_samples)
| 89 | track_losses.items()}) |
| 90 | |
def get_num_boxes(self, num_samples):
    """Return the normalisation count derived from ``num_samples``.

    ``num_samples`` is converted to a float tensor on ``self.sample_device``.
    When ``is_dist_avail_and_initialized()`` is true the tensor is
    all-reduced in place across processes. The result is then divided by
    ``get_world_size()``, clamped to a minimum of 1 and returned via
    ``.item()`` as a plain Python number.
    """
    total = torch.as_tensor(
        num_samples,
        dtype=torch.float,
        device=self.sample_device,
    )

    if is_dist_avail_and_initialized():
        # In-place collective; all_reduce's default reduction op is SUM.
        torch.distributed.all_reduce(total)

    # The lower bound of 1 keeps the returned value from dropping below one
    # (e.g. when there are no samples at all).
    averaged = total / get_world_size()
    return torch.clamp(averaged, min=1).item()
| 97 | |
| 98 | def get_loss(self, loss, outputs, gt_instances, indices, num_boxes, **kwargs): |
| 99 | loss_map = { |
no test coverage detected