Return a training dataflow. Each datapoint consists of the following: an image of shape (h, w, 3); one or more pairs of (anchor_labels, anchor_boxes), where anchor_labels has shape (h', w', NA) and anchor_boxes has shape (h', w', NA, 4); gt_boxes of shape (N, 4); gt_labels of shape (N,); and, if MODE_MASK, gt_masks of shape (N, h, w).
()
| 325 | |
| 326 | |
def get_train_dataflow():
    """
    Return a training dataflow. Each datapoint consists of the following:

    An image: (h, w, 3),

    1 or more pairs of (anchor_labels, anchor_boxes):
    anchor_labels: (h', w', NA)
    anchor_boxes: (h', w', NA, 4)

    gt_boxes: (N, 4)
    gt_labels: (N,)

    If MODE_MASK, gt_masks: (N, h, w)

    The roidbs of every dataset named in ``cfg.DATA.TRAIN`` are fetched from
    ``DatasetRegistry`` and concatenated. When ``cfg.DATA.FILTER_EMPTY_ANNOTATIONS``
    is set, images without any non-crowd box are dropped. Each roidb is then mapped
    through ``TrainingDataPreprocessor(cfg)``:

    * ``cfg.DATA.NUM_WORKERS > 0`` and ``cfg.TRAINER == "horovod"``: worker threads.
    * ``cfg.DATA.NUM_WORKERS > 0`` otherwise: worker processes.
    * ``cfg.DATA.NUM_WORKERS <= 0``: inline, in the calling thread.
    """
    roidbs = list(
        itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN)
    )
    print_class_histogram(roidbs)

    # Filter out images that have no gt boxes, but this filter shall not be applied for testing.
    # The model does support training with empty images, but it is not useful for COCO.
    num = len(roidbs)
    if cfg.DATA.FILTER_EMPTY_ANNOTATIONS:
        # Keep images that have at least one non-crowd box. Boolean-mask indexing
        # assumes img["boxes"] / img["is_crowd"] are numpy arrays -- confirm against the dataset loaders.
        roidbs = [img for img in roidbs if len(img["boxes"][img["is_crowd"] == 0]) > 0]
        logger.info(
            "Filtered {} images which contain no non-crowd groundtruth boxes. Total #images for training: {}".format(
                num - len(roidbs), len(roidbs)
            )
        )

    ds = DataFromList(roidbs, shuffle=True)

    preprocess = TrainingDataPreprocessor(cfg)

    if cfg.DATA.NUM_WORKERS > 0:
        if cfg.TRAINER == "horovod":
            # MPI does not like fork(), so use threads rather than processes here.
            buffer_size = cfg.DATA.NUM_WORKERS * 10  # one dataflow for each process, therefore don't need large buffer
            ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
        else:
            buffer_size = cfg.DATA.NUM_WORKERS * 20
            ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    else:
        # No workers requested: preprocess inline in the consuming thread.
        ds = MapData(ds, preprocess)
    return ds
| 371 | |
| 372 | |
| 373 | def get_eval_dataflow(name, shard=0, num_shards=1): |
no test coverage detected