(self)
| 221 | return 0 |
| 222 | |
def get_batch(self):
    """Build the current mini-batch and store it on ``self.data`` / ``self.label``.

    Gathers the roidb entries at positions ``self.cur`` .. ``self.cur + batch_size``
    of ``self.index`` and splits them across ``self.ctx`` according to
    ``self.work_load_list``. Each device slice goes through ``get_rcnn_batch``.
    The per-device outputs are merged key-by-key with ``tensor_vstack``.

    Side effects:
        Sets ``self.data`` to one ``mx.nd.array`` per name in ``self.data_name``.
        Sets ``self.label`` to one per name in ``self.label_name``.
        Does not advance ``self.cur``.

    Raises:
        AssertionError: if ``work_load_list`` is not a list with one entry per
            context (the check is skipped when Python runs with ``-O``).
    """
    # Slice the roidb for this batch. ``cur_to`` is clamped to ``self.size``,
    # so the last batch may contain fewer than ``self.batch_size`` entries.
    cur_from = self.cur
    cur_to = min(cur_from + self.batch_size, self.size)
    roidb = [self.roidb[self.index[i]] for i in range(cur_from, cur_to)]

    # Decide the multi-device slices (even split when no work load is given).
    work_load_list = self.work_load_list
    ctx = self.ctx
    if work_load_list is None:
        work_load_list = [1] * len(ctx)
    assert isinstance(work_load_list, list) and len(work_load_list) == len(ctx), \
        "Invalid settings for work load. "
    # Split over the entries actually gathered rather than self.batch_size.
    # Otherwise a clamped final batch would index past the end of ``roidb``.
    # For a full batch, len(roidb) == self.batch_size, so nothing changes.
    slices = _split_input_slice(len(roidb), work_load_list)

    # Build each device's (data, label) dicts from its share of the roidb.
    data_list = []
    label_list = []
    for islice in slices:
        data, label = get_rcnn_batch(roidb[islice.start:islice.stop], self.cfg)
        data_list.append(data)
        label_list.append(label)

    # Merge the per-device dicts key-by-key. The keys are taken from the first
    # device's output; presumably every device yields the same keys.
    all_data = {key: tensor_vstack([batch[key] for batch in data_list])
                for key in data_list[0]}
    all_label = {key: tensor_vstack([batch[key] for batch in label_list])
                 for key in label_list[0]}

    self.data = [mx.nd.array(all_data[name]) for name in self.data_name]
    self.label = [mx.nd.array(all_label[name]) for name in self.label_name]
| 257 | |
| 258 | def get_batch_individual(self): |
| 259 | # slice roidb |
nothing calls this directly
no test coverage detected