(args, model, criterion, postprocessors, data_loader, base_ds, device, output_dir, tracker=None,
phase='train', det_val=False)
| 196 | |
| 197 | @torch.no_grad() |
| 198 | def evaluate_track(args, model, criterion, postprocessors, data_loader, base_ds, device, output_dir, tracker=None, |
| 199 | phase='train', det_val=False): |
| 200 | model.eval() |
| 201 | criterion.eval() |
| 202 | |
| 203 | metric_logger = utils.MetricLogger(delimiter=" ") |
| 204 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) |
| 205 | header = 'Test:' |
| 206 | |
| 207 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) |
| 208 | coco_evaluator = CocoEvaluator(base_ds, iou_types) |
| 209 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] |
| 210 | |
| 211 | res_tracks = dict() |
| 212 | pre_embed = None |
| 213 | for samples, targets in metric_logger.log_every(data_loader, 50, header): |
| 214 | # pre process for track. |
| 215 | if tracker is not None: |
| 216 | frame_id = targets[0].get("frame_id", None) |
| 217 | assert frame_id is not None |
| 218 | frame_id = frame_id.item() |
| 219 | if frame_id == 1: |
| 220 | tracker = BYTETracker(args) |
| 221 | pre_embed = None |
| 222 | |
| 223 | samples = samples.to(device) |
| 224 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] |
| 225 | |
| 226 | if det_val: |
| 227 | outputs = model(samples) |
| 228 | else: |
| 229 | outputs, pre_embed = model(samples, pre_embed) |
| 230 | loss_dict = criterion(outputs, targets) |
| 231 | weight_dict = criterion.weight_dict |
| 232 | |
| 233 | # reduce losses over all GPUs for logging purposes |
| 234 | loss_dict_reduced = utils.reduce_dict(loss_dict) |
| 235 | loss_dict_reduced_scaled = {k: v * weight_dict[k] |
| 236 | for k, v in loss_dict_reduced.items() if k in weight_dict} |
| 237 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v |
| 238 | for k, v in loss_dict_reduced.items()} |
| 239 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), |
| 240 | **loss_dict_reduced_scaled, |
| 241 | **loss_dict_reduced_unscaled) |
| 242 | metric_logger.update(class_error=loss_dict_reduced['class_error']) |
| 243 | |
| 244 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) |
| 245 | results = postprocessors['bbox'](outputs, orig_target_sizes) |
| 246 | |
| 247 | if 'segm' in postprocessors.keys(): |
| 248 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) |
| 249 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) |
| 250 | res = {target['image_id'].item(): output for target, output in zip(targets, results)} |
| 251 | |
| 252 | # post process for track. |
| 253 | if tracker is not None: |
| 254 | res_track = tracker.update(results[0]) |
| 255 | res_tracks[targets[0]['image_id'].item()] = res_track |
no test coverage detected