(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False,
args={})
| 122 | return output |
| 123 | |
| 124 | def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False, verbose=False, is_densecap=False, |
| 125 | args={}): |
| 126 | # segment with prompt |
| 127 | print("CA prompt: ", prompt, "CA controls", controls) |
| 128 | is_seg_everything = 'everything' in prompt['prompt_type'] |
| 129 | |
| 130 | args['seg_crop_mode'] = args.get('seg_crop_mode', self.args.seg_crop_mode) |
| 131 | args['clip_filter'] = args.get('clip_filter', self.args.clip_filter) |
| 132 | args['disable_regular_box'] = args.get('disable_regular_box', self.args.disable_regular_box) |
| 133 | args['context_captions'] = args.get('context_captions', self.args.context_captions) |
| 134 | args['enable_reduce_tokens'] = args.get('enable_reduce_tokens', self.args.enable_reduce_tokens) |
| 135 | args['enable_morphologyex'] = args.get('enable_morphologyex', self.args.enable_morphologyex) |
| 136 | args['topN'] = args.get('topN', 10) if is_seg_everything else 1 |
| 137 | args['min_mask_area'] = args.get('min_mask_area', 0) |
| 138 | |
| 139 | if not is_densecap: |
| 140 | seg_results = self.segmenter.inference(image, prompt) |
| 141 | else: |
| 142 | seg_results = self.segmenter_densecap.inference(image, prompt) |
| 143 | |
| 144 | seg_masks, seg_bbox, seg_area = seg_results if is_seg_everything else (seg_results, None, None) |
| 145 | |
| 146 | if args['topN'] > 1: # sort by area |
| 147 | samples = list(zip(*[seg_masks, seg_bbox, seg_area])) |
| 148 | # top_samples = sorted(samples, key=lambda x: x[2], reverse=True) |
| 149 | # seg_masks, seg_bbox, seg_area = list(zip(*top_samples)) |
| 150 | samples = list(filter(lambda x: x[2] > args['min_mask_area'], samples)) |
| 151 | samples = samples[:args['topN']] |
| 152 | seg_masks, seg_bbox, seg_area = list(zip(*samples)) |
| 153 | |
| 154 | out_list = [] |
| 155 | for i, seg_mask in enumerate(seg_masks): |
| 156 | if args['enable_morphologyex']: |
| 157 | seg_mask = 255 * seg_mask.astype(np.uint8) |
| 158 | seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1) |
| 159 | seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8)) |
| 160 | seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8)) |
| 161 | seg_mask = seg_mask[:, :, 0] > 0 |
| 162 | |
| 163 | seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.) |
| 164 | mask_save_path = None |
| 165 | |
| 166 | if verbose: |
| 167 | mask_save_path = f'result/mask_{time.time()}.png' |
| 168 | if not os.path.exists(os.path.dirname(mask_save_path)): |
| 169 | os.makedirs(os.path.dirname(mask_save_path)) |
| 170 | |
| 171 | if seg_mask_img.mode != 'RGB': |
| 172 | seg_mask_img = seg_mask_img.convert('RGB') |
| 173 | seg_mask_img.save(mask_save_path) |
| 174 | print('seg_mask path: ', mask_save_path) |
| 175 | print("seg_mask.shape: ", seg_mask.shape) |
| 176 | |
| 177 | # captioning with mask |
| 178 | if args['enable_reduce_tokens']: |
| 179 | result = self.captioner.inference_with_reduced_tokens(image, seg_mask, |
| 180 | crop_mode=args['seg_crop_mode'], |
| 181 | filter=args['clip_filter'], |
no test coverage detected