Given ground-truth bounding boxes and labels, build a graph for validation.
(
img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
)
| 165 | |
| 166 | |
def build_graph_validate_gt_obj(
    img, gt_ids, bbox, spatial_feat, bbox_improvement=True, overlap=False
):
    """Given ground-truth bboxes and labels, build graphs for validation.

    For each image in the batch, every box whose coordinates sum to a
    positive value becomes a node. Every ordered pair of distinct nodes is
    joined by an edge, giving a complete directed graph without self-loops.

    Args:
        img: batched image NDArray. ``img.shape[0]`` is used as the batch
            size and ``img.shape[2:4]`` as (height, width). This assumes an
            NCHW layout -- TODO confirm. ``img.context`` is used for newly
            created arrays.
        gt_ids: ground-truth ids; ``gt_ids[b, i, 0]`` is stored as node data
            ``node_class_pred`` for box ``i`` of image ``b``.
        bbox: ground-truth boxes with 4 coordinates on the last axis,
            presumably (x1, y1, x2, y2) in pixels -- verify against caller.
            NOTE: normalized *in place* (x by width, y by height), so the
            caller's array is modified.
        spatial_feat: per-box features, stored as node data ``node_feat``.
        bbox_improvement: if True, node ``pred_bbox`` is replaced by
            ``bbox_improve(pred_bbox)`` before edge features are computed.
        overlap: currently unused; kept for signature compatibility.

    Returns:
        ``None`` if no image produced a graph, the single ``DGLGraph`` if
        exactly one did, otherwise the ``dgl.batch`` of all of them.
        Images with fewer than two valid boxes are skipped, since they have
        no node pairs to connect.
    """
    n_batch = img.shape[0]
    img_height, img_width = img.shape[2:4]
    # Scale to image-relative coordinates; this mutates the caller's bbox.
    bbox[:, :, 0] /= img_width
    bbox[:, :, 1] /= img_height
    bbox[:, :, 2] /= img_width
    bbox[:, :, 3] /= img_height
    ctx = img.context

    g_batch = []
    for btc in range(n_batch):
        # Rows whose coordinates sum to zero are treated as padding.
        inds = np.where(bbox[btc].sum(1).asnumpy() > 0)[0].tolist()
        n_nodes = len(inds)
        # Fewer than two boxes means no node pairs: the pair list below
        # would be empty and unpacking it into (src, dst) would raise.
        if n_nodes < 2:
            continue

        g_pred = dgl.DGLGraph()
        g_pred.add_nodes(
            n_nodes,
            {
                "pred_bbox": bbox[btc, inds],
                "node_feat": spatial_feat[btc, inds],
                "node_class_pred": gt_ids[btc, inds, 0],
                "node_class_logit": nd.zeros_like(
                    gt_ids[btc, inds, 0], ctx=ctx
                ),
            },
        )

        # Unordered pairs (i < j); each pair is then added in both directions.
        pairs = [
            (i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)
        ]
        src, dst = zip(*pairs)
        g_pred.add_edges(src, dst)
        g_pred.add_edges(dst, src)

        n_edges = g_pred.number_of_edges()
        if bbox_improvement:
            g_pred.ndata["pred_bbox"] = bbox_improve(g_pred.ndata["pred_bbox"])
        g_pred.edata["rel_bbox"] = extract_edge_bbox(g_pred)
        # Tag every edge with the index of the image it came from.
        g_pred.edata["batch_id"] = nd.zeros((n_edges, 1), ctx=ctx) + btc

        g_batch.append(g_pred)

    if not g_batch:
        return None
    if len(g_batch) == 1:
        return g_batch[0]
    return dgl.batch(g_batch)
| 220 | |
| 221 | |
| 222 | def build_graph_validate_gt_bbox( |
no test coverage detected