MCPcopy
hub / github.com/msracver/Deformable-ConvNets / get_batch

Method get_batch

faster_rcnn/core/loader.py:223–256  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

221 return 0
222
def get_batch(self):
    """Assemble the mini-batch starting at ``self.cur`` and store it on self.

    The roidb entries for positions ``self.cur`` up to
    ``self.cur + self.batch_size`` are taken through the ``self.index``
    permutation, with the upper bound clamped to ``self.size``. They are
    split across the devices in ``self.ctx`` according to
    ``self.work_load_list``, and each device's share is turned into
    ``(data, label)`` dicts by ``get_rcnn_batch``. The per-device dicts are
    merged key by key with ``tensor_vstack``. The merged arrays are then
    converted to ``mx.nd.array`` and stored in ``self.data`` / ``self.label``,
    ordered by ``self.data_name`` / ``self.label_name``.

    Returns:
        None. Results are written to ``self.data`` and ``self.label``.

    Raises:
        AssertionError: if ``work_load_list`` is not a list with one entry
            per context in ``self.ctx``.
    """
    # slice roidb; cur_to is clamped, so the final batch may be short
    cur_from = self.cur
    cur_to = min(cur_from + self.batch_size, self.size)
    roidb = [self.roidb[self.index[i]] for i in range(cur_from, cur_to)]

    # decide multi device slices
    work_load_list = self.work_load_list
    ctx = self.ctx
    if work_load_list is None:
        work_load_list = [1] * len(ctx)
    assert isinstance(work_load_list, list) and len(work_load_list) == len(ctx), \
        "Invalid settings for work load. "
    # Split the number of entries actually fetched rather than
    # self.batch_size: for a short final batch, slices computed over
    # batch_size would reach past the end of roidb.
    slices = _split_input_slice(len(roidb), work_load_list)

    # build each device's share of the batch
    data_list = []
    label_list = []
    for islice in slices:
        data, label = get_rcnn_batch(roidb[islice.start:islice.stop], self.cfg)
        data_list.append(data)
        label_list.append(label)

    # merge the per-device dicts key by key
    all_data = {key: tensor_vstack([batch[key] for batch in data_list])
                for key in data_list[0]}
    all_label = {key: tensor_vstack([batch[key] for batch in label_list])
                 for key in label_list[0]}

    self.data = [mx.nd.array(all_data[name]) for name in self.data_name]
    self.label = [mx.nd.array(all_label[name]) for name in self.label_name]
257
258 def get_batch_individual(self):
259 # slice roidb

Callers

nothing calls this directly

Calls 2

get_rcnn_batch — Function · 0.90
tensor_vstack — Function · 0.90

Tested by

no test coverage detected