Compute the average per-step displacement (the "length") of the visual trace. pred_tracks: e.g., [1, 77, 225, 2]; pred_visibility: e.g., [1, 77, 225]; image_size: e.g., [720, 1280].
(self, pred_tracks, pred_visibility, image_size)
| 39 | return video, pred_tracks, pred_visibility |
| 40 | |
def visual_trace_length(self, pred_tracks, pred_visibility, image_size, *, use_visibility=False):
    """Return the mean per-step displacement of each tracked point.

    Despite the name, this is not a summed path length. The Euclidean
    distances between consecutive entries along dim 1 are *averaged*
    over that dim.

    Args:
        pred_tracks: Track coordinates. Dim 1 is differenced (treated as
            time) and dim 3 holds the coordinates, e.g. shape
            [1, 77, 225, 2].
        pred_visibility: Visibility flags aligned with the first three dims
            of ``pred_tracks``, e.g. shape [1, 77, 225]. Only read when
            ``use_visibility`` is True.
        image_size: Sequence (or tensor) of two numbers, e.g. [720, 1280].
            The last axis of ``pred_tracks`` is divided element-wise by it.
        use_visibility: Keyword-only. If True, each step is weighted by the
            visibility of its destination entry, and the result is divided
            by the number of visible steps (plus 1e-5 to avoid division by
            zero). Defaults to False, which gives the plain unweighted mean.

    Returns:
        Tensor with dims 0 and 2 of ``pred_tracks`` (e.g. [1, 225]) holding
        the averaged normalized displacement. With fewer than two entries
        along dim 1, the unweighted mean is taken over an empty axis and
        yields NaN.
    """
    # NOTE(review): image_size is applied to the coordinates in the order
    # given. If the tracks are (x, y) and image_size is (H, W), as the
    # [720, 1280] example suggests, then x is scaled by H and y by W.
    # Confirm the intended convention before relying on anisotropic scaling.
    scale = torch.as_tensor(image_size, dtype=torch.float32, device=pred_tracks.device)
    tracks_normalized = pred_tracks / scale  # broadcasts over the last axis

    # Distance between consecutive positions along dim 1 -> one fewer entry on dim 1.
    step_displacement = torch.norm(tracks_normalized[:, 1:] - tracks_normalized[:, :-1], dim=3)

    if not use_visibility:
        return step_displacement.mean(1)

    # Step t goes from entry t to entry t+1, so it is weighted by visibility[:, t+1].
    weights = pred_visibility[:, 1:].float().to(pred_tracks.device)
    return (step_displacement * weights).sum(1) / (1e-5 + weights.sum(1))
| 54 | |
| 55 | |
| 56 | def visualize(self, video, pred_tracks, pred_visibility, filename="visual_trace.mp4", mode="ranbow"): |
no outgoing calls
no test coverage detected