MCPcopy
hub / github.com/microsoft/Magma / som_prompting

Function som_prompting

data/utils/som_tom.py:6–78  ·  view source on GitHub ↗

Draws marks on the image.

(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False)

Source from the content-addressed store, hash-verified

4import matplotlib.pyplot as plt
5
6def som_prompting(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False):
7 """
8 draw marks on the image
9 """
10 image_size = image.size
11 draw = ImageDraw.Draw(image)
12
13 def get_text_size(text, image, font):
14 im = Image.new('RGB', (image.width, image.height))
15 draw = ImageDraw.Draw(im)
16 _, _, width, height = draw.textbbox((0, 0), text=text, font=font)
17 return width, height
18
19 def expand_bbox(bbox):
20 x1, y1, x2, y2 = bbox
21 return [x1-4, y1-4, x2+4, y2+4]
22
23 def draw_marks(draw, points, text_size, id, font_size):
24 txt = str(id)
25 draw.ellipse(((points[0]-max(text_size)//2-1, points[1]-max(text_size)//2-1, points[0]+max(text_size)//2+1, points[1]+max(text_size)//2+1)), fill='red')
26 draw.text((points[0]-text_size[0] // 2, points[1]-text_size[1] // 2-3), txt, fill='white', font=font_size)
27
28 fontsize = 1
29 font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
30 txt = "55"
31 while min(get_text_size(txt, image, font)) < 0.03*image_size[0]:
32 # iterate until the text size is just larger than the criteria
33 fontsize += 1
34 font = ImageFont.truetype("data/utils/arial.ttf", fontsize)
35
36 text_size_2digits = get_text_size('55', image, font)
37 text_size_1digit = get_text_size('5', image, font)
38 text_size = {
39 1: text_size_1digit,
40 2: text_size_2digits
41 }
42
43 # draw the starting point of positive traces on image
44 num_pos = pos_traces.shape[2]
45 pos_idx = torch.arange(num_pos)
46 pos_traces_to_mark = pos_traces
47
48 # random sample at most 10 negative traces
49 num_neg = neg_traces.shape[2]
50 neg_idx = torch.arange(num_neg)
51 neg_traces_to_mark = neg_traces
52
53 num_traces_total = pos_traces_to_mark.shape[2] + neg_traces_to_mark.shape[2]
54 # shuffle the indices
55 all_idx = torch.randperm(num_traces_total)
56
57 pos_mark_ids = []; neg_mark_ids = []
58 pos_traces_som = {}
59 for i in range(pos_traces_to_mark.shape[2]):
60 pos = pos_traces_to_mark[:,:,i]
61 mark_id = all_idx[i].item()
62 text_size = get_text_size(str(mark_id+1), image, font)
63 if draw_som_positive:

Callers 4

som_tom · Function · 0.90
_construct_conv_som · Method · 0.90
_construct_conv · Method · 0.90
__call__ · Method · 0.90

Calls 2

get_text_size · Function · 0.85
draw_marks · Function · 0.85

Tested by

no test coverage detected