(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
length, image_embedding, state, click_state, original_size, input_size, text_refiner,
evt: gr.SelectData)
| 186 | |
| 187 | |
def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                    length, image_embedding, state, click_state, original_size, input_size, text_refiner,
                    evt: gr.SelectData):
    """Caption the image region selected by a click and stream the results.

    This is a generator. It always yields once with the raw caption and, when
    GPT refinement is enabled (``args.disable_gpt`` is falsy) and the model has
    a text refiner, yields a second time with the refined caption.

    Args:
        image_input: Image handed to the model and used as the drawing base.
        point_prompt: ``'Positive'`` marks the click with label 1; any other
            value marks it with label 0.
        click_mode: Forwarded to ``get_click_prompt`` and ``update_click_state``.
        enable_wiki: Flag accepted as ``True`` or one of the strings
            'True'/'TRUE'/'true'/'Yes'/'YES'/'yes'; anything else means off.
        language, sentiment, factuality, length: Caption controls, passed to
            the model as a dict.
        image_embedding, original_size, input_size: Passed to ``model.setup``
            together with ``is_image_set=True``.
        state: Chat history; new ``(user, bot)`` tuples are appended to a copy.
        click_state: Click history passed to ``get_click_prompt`` and
            ``update_click_state`` (presumably updated in place -- confirm).
        text_refiner: Refiner handed to the model builder.
        evt: Gradio select event; ``evt.index[0]`` and ``evt.index[1]`` are
            used as the click position.

    Yields:
        Tuples ``(state, state, click_state, annotated_image, wiki)``.
    """
    click_index = evt.index

    # The third element of the coordinate is the point label: 1 for a
    # positive click, 0 otherwise.
    template = "[[{}, {}, 1]]" if point_prompt == 'Positive' else "[[{}, {}, 0]]"
    coordinate = template.format(str(click_index[0]), str(click_index[1]))

    prompt = get_click_prompt(coordinate, click_state, click_mode)
    input_points = prompt['input_point']
    input_labels = prompt['input_label']

    controls = {
        'length': length,
        'sentiment': sentiment,
        'factuality': factuality,
        'language': language,
    }

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        text_refiner=text_refiner,
        session_id=iface.app_id
    )
    model.setup(image_embedding, original_size, input_size, is_image_set=True)

    # Normalise the UI value (bool or string) to a real bool.
    enable_wiki = enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes']

    # First pass runs with disable_gpt=True; refinement, if any, happens below.
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)[0]

    user_entry = ("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)
    captions = out['generated_captions']
    text = captions['raw_caption']
    state = state + [user_entry, (None, "raw_caption: {}".format(text))]
    wiki = captions.get('wiki', "")
    update_click_state(click_state, text, click_mode)

    input_mask = np.array(out['mask'].convert('P'))
    painted_image = mask_painter(np.array(image_input), input_mask)

    def draw_bubble(caption):
        # Both stages draw their caption bubble on the same mask-painted image.
        return create_bubble_frame(painted_image, caption, (click_index[0], click_index[1]), input_mask,
                                   input_points=input_points, input_labels=input_labels)

    yield state, state, click_state, draw_bubble(text), wiki

    if args.disable_gpt or not model.text_refiner:
        return

    refined_caption = model.text_refiner.inference(query=text, controls=controls,
                                                   context=out['context_captions'],
                                                   enable_wiki=enable_wiki)
    new_cap = refined_caption['caption']
    wiki = refined_caption['wiki']
    state = state + [(None, f"caption: {new_cap}")]
    yield state, state, click_state, draw_bubble(new_cap), wiki
| 243 | |
| 244 | |
| 245 | def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True): |
nothing calls this directly
no test coverage detected