MCPcopy Index your code
hub / github.com/NVIDIA/FastPhotoStyle / segment_this_img

Function segment_this_img

demo_with_ade20k_ssn.py:83–114  ·  view source on GitHub ↗
(f)

Source from the content-addressed store, hash-verified

81
82
def segment_this_img(f):
    """Predict a per-pixel class-index map for the image at path ``f``.

    The image is resized once per entry of ``args.imgSize`` (each entry is
    the target length of the shorter side), every resized copy is run
    through ``segmentation_module``, and the predictions are averaged
    before taking the arg-max over the class dimension.

    Args:
        f: Path of the image to segment; passed straight to ``imread``.

    Returns:
        The result of ``as_numpy`` applied to the arg-max class indices
        with the leading batch dimension squeezed out, i.e. one class
        index per pixel at the original image resolution.
    """
    img = imread(f, mode='RGB')
    # Reverse the channel axis. The image is requested with mode='RGB', so
    # this presumably yields BGR order (the earlier "BGR to RGB" comment
    # stated the opposite direction) -- TODO confirm against the model.
    img = img[:, :, ::-1]
    ori_height, ori_width, _ = img.shape
    short_side = float(min(ori_height, ori_width))

    img_resized_list = []
    for this_short_size in args.imgSize:
        scale = this_short_size / short_side
        target_height = int(ori_height * scale)
        target_width = int(ori_width * scale)
        # Snap both dimensions to a multiple of the padding constant.
        target_height = round2nearest_multiple(target_height, args.padding_constant)
        target_width = round2nearest_multiple(target_width, args.padding_constant)

        # cv2.resize takes the size as (width, height). The copy turns the
        # reversed-channel view into a plain array (presumably required by
        # cv2 -- verify before removing).
        img_resized = cv2.resize(img.copy(), (target_width, target_height))
        img_resized = img_resized.astype(np.float32)
        # HWC -> CHW, then add a leading batch dimension of size 1.
        img_resized = img_resized.transpose((2, 0, 1))
        img_resized = transform(torch.from_numpy(img_resized))
        img_resized = torch.unsqueeze(img_resized, 0)
        img_resized_list.append(img_resized)

    # Predictions are produced at the original resolution.
    seg_size = img.shape[:2]
    num_scales = len(args.imgSize)

    with torch.no_grad():
        pred = torch.zeros(1, args.num_class, seg_size[0], seg_size[1])
        for timg in img_resized_list:
            feed_dict = dict()
            # NOTE(review): the tensor is moved with .cuda() and then handed
            # to async_copy_to(..., args.gpu_id); one of the two transfers
            # looks redundant -- confirm before simplifying.
            feed_dict['img_data'] = timg.cuda()
            feed_dict = async_copy_to(feed_dict, args.gpu_id)
            # forward pass
            pred_tmp = segmentation_module(feed_dict, segSize=seg_size)
            # Accumulate the mean over all scales on the CPU.
            pred = pred + pred_tmp.cpu() / num_scales
        _, preds = torch.max(pred, dim=1)
        preds = as_numpy(preds.squeeze(0))
    return preds
115
116
# Run segment_this_img on the image at args.content_image_path and keep the
# returned prediction in the module-level name cont_seg.
117cont_seg = segment_this_img(args.content_image_path)

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected