(self, image, caption, with_logits=True)
| 1067 | return model |
| 1068 | |
def get_grounding_boxes(self, image, caption, with_logits=True):
    """Run the grounding model on one image and return thresholded boxes and phrases.

    The caption is normalised (lower-cased, stripped, terminated with ".")
    before it is handed to ``self.grounding``. Predictions whose highest
    sigmoid score is above ``self.box_threshold`` are kept, and a phrase is
    built for each kept prediction from the positions whose score is above
    ``self.text_threshold``.

    Args:
        image: Tensor for a single image. It is moved to ``self.device`` and
            given a leading batch axis with ``image[None]``.
            NOTE(review): expected layout/normalisation is not visible
            here -- confirm against the caller.
        caption: Text prompt describing what to ground.
        with_logits: When true, each phrase gets a ``"(score)"`` suffix,
            where ``score`` is the first four characters of the string
            form of that prediction's maximum score.

    Returns:
        A ``(boxes_filt, pred_phrases)`` tuple: the CPU tensor of kept
        boxes (one row per kept prediction) and a list with one phrase per
        kept box, in the same order.
        NOTE(review): box coordinate convention is defined by the model's
        ``pred_boxes`` output -- confirm before converting to pixels.
    """
    caption = caption.lower().strip()
    if not caption.endswith("."):
        caption += "."

    image = image.to(self.device)
    with torch.no_grad():
        outputs = self.grounding(image[None], captions=[caption])

    # Drop the batch axis; everything below works on CPU tensors.
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256) per the original note
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4) per the original note

    # Boolean-mask indexing already yields new tensors, so no explicit
    # clone of the model outputs is needed before filtering.
    keep = logits.max(dim=1)[0] > self.box_threshold
    logits_filt = logits[keep]
    boxes_filt = boxes[keep]

    tokenizer = self.grounding.tokenizer
    tokenized = tokenizer(caption)

    pred_phrases = []
    for logit in logits_filt:
        phrase = get_phrases_from_posmap(logit > self.text_threshold, tokenized, tokenizer)
        if with_logits:
            # Truncation (not rounding) of the score text is kept on purpose
            # so existing consumers of these labels see the same strings.
            phrase = phrase + f"({str(logit.max().item())[:4]})"
        pred_phrases.append(phrase)

    return boxes_filt, pred_phrases
| 1102 | |
| 1103 | def plot_boxes_to_image(self, image_pil, tgt): |
| 1104 | H, W = tgt["size"] |
no outgoing calls
no test coverage detected