Removes points from the 2D trajectory that are closer than min_distance to the most recently retained point (the first point is always kept). Parameters: trajectory (torch.Tensor): A tensor of shape (N, 2) representing N points in 2D space. min_distance (float): The minimum distance a point must be from the last retained point to be kept.
(self, trajectory, min_distance=2)
| 120 | return sampled_ids |
| 121 | |
def remove_close_points_tensor(self, trajectory, min_distance=2):
    """
    Greedily thin a trajectory so consecutive retained points are >= min_distance apart.

    The first point is always kept. Each later point is kept only if its
    Euclidean distance to the most recently *kept* point (not the immediately
    preceding input point) is at least ``min_distance``.

    Parameters:
        trajectory (torch.Tensor): Tensor of shape (N, D) holding N points
            (D = 2 for a 2D trajectory). Must be a floating-point tensor,
            since ``torch.norm`` requires one.
        min_distance (float): Minimum spacing between consecutive retained points.

    Returns:
        torch.Tensor: Tensor of shape (K, D), K <= N, with the retained points in
        their original order, on the same device and dtype as ``trajectory``.
        Gradients flow back to the selected rows. An empty input yields an
        empty copy of the input instead of raising.
    """
    if trajectory.size(0) == 0:
        # Nothing to filter; the original code crashed on trajectory[0] here.
        return trajectory.clone()

    # The selection is inherently sequential (each decision depends on the last
    # kept point). Doing it on a detached CPU copy avoids a device sync per
    # comparison when the input lives on an accelerator, and keeps autograd
    # out of the bookkeeping.
    points = trajectory.detach().cpu()

    kept_indices = [0]
    last_kept = points[0]
    for i in range(1, points.size(0)):
        curr_point = points[i]
        # Keep the point if it is far enough from the last *retained* point.
        if torch.norm(curr_point - last_kept) >= min_distance:
            kept_indices.append(i)
            last_kept = curr_point

    # Index the original tensor so the result keeps its device, dtype and
    # gradient connection to the selected points (as torch.stack did).
    index = torch.tensor(kept_indices, dtype=torch.long, device=trajectory.device)
    return trajectory[index]
no outgoing calls
no test coverage detected