Draw marks on the image.
(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False)
| 4 | import matplotlib.pyplot as plt |
| 5 | |
| 6 | def som_prompting(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False): |
| 7 | """ |
| 8 | draw marks on the image |
| 9 | """ |
| 10 | image_size = image.size |
| 11 | draw = ImageDraw.Draw(image) |
| 12 | |
def get_text_size(text, image, font):
    """Measure ``text`` as rendered with ``font``.

    Returns ``(right, bottom)``: the right and bottom edges of the text's
    bounding box when drawn at the origin, as reported by
    ``ImageDraw.textbbox``. These equal the rendered width/height only when
    the box's left/top offsets are zero, because they include any bearing
    offset of the font. Callers rely on this exact value, so it is kept.

    ``image`` is still accepted so existing call sites keep working, but it
    is no longer used. The bounding box depends only on the text and font,
    not on the canvas size. Measuring on a 1x1 scratch canvas therefore
    avoids allocating a full image-sized buffer on every call. This helper
    runs repeatedly in the font-size search loop.
    """
    scratch = ImageDraw.Draw(Image.new('RGB', (1, 1)))
    _, _, right, bottom = scratch.textbbox((0, 0), text=text, font=font)
    return right, bottom
| 18 | |
def expand_bbox(bbox, *, pad=4):
    """Grow a bounding box outward by ``pad`` on every side.

    Args:
        bbox: ``(x1, y1, x2, y2)`` box; any 4-item unpackable sequence.
        pad: margin added on each side. It defaults to 4, the previously
            hard-coded value, so existing callers are unaffected.

    Returns:
        A new list ``[x1 - pad, y1 - pad, x2 + pad, y2 + pad]``.
    """
    x1, y1, x2, y2 = bbox
    return [x1 - pad, y1 - pad, x2 + pad, y2 + pad]
| 22 | |
def draw_marks(draw, points, text_size, id, font_size):
    """Draw one numbered mark: a red disc with a white label on top.

    Args:
        draw: drawing context exposing ``ellipse`` and ``text`` (e.g. a
            PIL ``ImageDraw.Draw``).
        points: indexable ``(x, y)`` centre of the mark.
        text_size: ``(w, h)`` label extent, as produced by ``get_text_size``.
        id: value shown as the label (converted with ``str``).
        font_size: despite its name, the font object handed to ``draw.text``.
    """
    label = str(id)
    cx, cy = points[0], points[1]
    # Disc size follows the larger label dimension, plus a 1px margin.
    half = max(text_size) // 2
    disc = (cx - half - 1, cy - half - 1, cx + half + 1, cy + half + 1)
    draw.ellipse(disc, fill='red')
    # Centre the label on the disc; the extra 3px shift nudges it upward.
    origin = (cx - text_size[0] // 2, cy - text_size[1] // 2 - 3)
    draw.text(origin, label, fill='white', font=font_size)
| 27 | |
| 28 | fontsize = 1 |
| 29 | font = ImageFont.truetype("data/utils/arial.ttf", fontsize) |
| 30 | txt = "55" |
| 31 | while min(get_text_size(txt, image, font)) < 0.03*image_size[0]: |
| 32 | # iterate until the text size is just larger than the criteria |
| 33 | fontsize += 1 |
| 34 | font = ImageFont.truetype("data/utils/arial.ttf", fontsize) |
| 35 | |
| 36 | text_size_2digits = get_text_size('55', image, font) |
| 37 | text_size_1digit = get_text_size('5', image, font) |
| 38 | text_size = { |
| 39 | 1: text_size_1digit, |
| 40 | 2: text_size_2digits |
| 41 | } |
| 42 | |
| 43 | # draw the starting point of positive traces on image |
| 44 | num_pos = pos_traces.shape[2] |
| 45 | pos_idx = torch.arange(num_pos) |
| 46 | pos_traces_to_mark = pos_traces |
| 47 | |
| 48 | # random sample at most 10 negative traces |
| 49 | num_neg = neg_traces.shape[2] |
| 50 | neg_idx = torch.arange(num_neg) |
| 51 | neg_traces_to_mark = neg_traces |
| 52 | |
| 53 | num_traces_total = pos_traces_to_mark.shape[2] + neg_traces_to_mark.shape[2] |
| 54 | # shuffle the indices |
| 55 | all_idx = torch.randperm(num_traces_total) |
| 56 | |
| 57 | pos_mark_ids = []; neg_mark_ids = [] |
| 58 | pos_traces_som = {} |
| 59 | for i in range(pos_traces_to_mark.shape[2]): |
| 60 | pos = pos_traces_to_mark[:,:,i] |
| 61 | mark_id = all_idx[i].item() |
| 62 | text_size = get_text_size(str(mark_id+1), image, font) |
| 63 | if draw_som_positive: |
no test coverage detected