Construct conversations for spatial prediction
(self, item, image, visual_traces, frame_pos, pos_traces_to_mark=None, neg_traces_to_mark=None, normalize=True)
| 132 | |
| 133 | |
| 134 | def _construct_conv_som(self, item, image, visual_traces, frame_pos, pos_traces_to_mark=None, neg_traces_to_mark=None, normalize=True): |
| 135 | """ |
| 136 | Construct conversations for spatial prediction |
| 137 | """ |
| 138 | |
| 139 | if pos_traces_to_mark is None or neg_traces_to_mark is None: |
| 140 | pred_tracks = visual_traces['pred_tracks'] |
| 141 | pred_visibility = visual_traces['pred_visibility'] |
| 142 | # randomly sample pos_tracks and neg_tracks |
| 143 | num_clusters_pos = torch.randint(2, 6, (1,)).item() |
| 144 | num_clusters_neg = torch.randint(6, 15, (1,)).item() |
| 145 | pos_tracks = pred_tracks[:,frame_pos:,torch.randint(0, pred_tracks.size(2), (num_clusters_pos,))] |
| 146 | neg_tracks = pred_tracks[:,frame_pos:,torch.randint(0, pred_tracks.size(2), (num_clusters_neg,))] |
| 147 | |
| 148 | image, pos_traces_to_mark, neg_traces_to_mark, pos_mark_ids, neg_mark_ids, all_idx = \ |
| 149 | som_prompting(image, pos_tracks, neg_tracks, draw_som_positive=True, draw_som_negative=True) |
| 150 | |
| 151 | conv_user = ( |
| 152 | f"{self.image_placeholder}\nThe image is split into {self.spatial_quant_size}x{self.spatial_quant_size} grids, and labeled with numeric marks.\n" |
| 153 | f"Please locate all the numerical marks in the image.\n" |
| 154 | ) |
| 155 | |
| 156 | # combine pos_traces_to_mark and neg_traces_to_mark |
| 157 | pos_traces_to_mark.update(neg_traces_to_mark) |
| 158 | # sort pos_traces_to_mark by the key |
| 159 | pos_traces_to_mark = dict(sorted(pos_traces_to_mark.items())) |
| 160 | |
| 161 | marks_pos = [] |
| 162 | for key, val in pos_traces_to_mark.items(): |
| 163 | trace = val[0] |
| 164 | if normalize: |
| 165 | x = int(self.spatial_quant_size * trace[0, 0] / image.size[0]) |
| 166 | y = int(self.spatial_quant_size * trace[0, 1] / image.size[1]) |
| 167 | else: |
| 168 | x = int(trace[0, 0]) |
| 169 | y = int(trace[0, 1]) |
| 170 | val_str = f"[{x},{y}]" |
| 171 | marks_pos.append(f'Mark {key} at {val_str}') |
| 172 | |
| 173 | conv_gpt = ". ".join(marks_pos) + '\n' |
| 174 | return conv_user, conv_gpt, image |
| 175 | |
| 176 | def _construct_conv_tom(self, item, video_path, visual_traces): |
| 177 | """ |
no test coverage detected