hub / github.com/tensorpack/tensorpack / get_train_dataflow

Function get_train_dataflow

examples/FasterRCNN/data.py:327–370

Return a training dataflow. Each datapoint consists of the following:
An image: (h, w, 3)
1 or more pairs of (anchor_labels, anchor_boxes), with anchor_labels: (h', w', NA) and anchor_boxes: (h', w', NA, 4)
gt_boxes: (N, 4)
gt_labels: (N,)
If MODE_MASK, gt_masks: (N, h, w)
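The shapes listed in the summary can be made concrete with a small sketch. The helper `expected_shapes`, its parameters, and the dictionary keys are hypothetical: they copy the names used in the docstring, not the keys of the real datapoint. It also assumes that each (anchor_labels, anchor_boxes) pair corresponds to one feature map of size (h', w'), which the docstring does not state explicitly.

```python
from typing import Dict, List, Tuple


def expected_shapes(
    h: int,
    w: int,
    num_boxes: int,
    feature_sizes: List[Tuple[int, int]],
    num_anchors: int,
    mode_mask: bool = False,
) -> Dict[str, object]:
    """Return the shapes the docstring lists for one datapoint.

    Hypothetical helper: key names mirror the docstring, not the real dict keys.
    """
    shapes: Dict[str, object] = {
        "image": (h, w, 3),
        # Assumption: one (anchor_labels, anchor_boxes) pair per feature map (h', w').
        "anchors": [
            {
                "anchor_labels": (fh, fw, num_anchors),
                "anchor_boxes": (fh, fw, num_anchors, 4),
            }
            for fh, fw in feature_sizes
        ],
        "gt_boxes": (num_boxes, 4),
        "gt_labels": (num_boxes,),
    }
    if mode_mask:
        # Masks are only part of the datapoint when MODE_MASK is enabled.
        shapes["gt_masks"] = (num_boxes, h, w)
    return shapes


# Illustrative values only; they are not taken from any real configuration.
example = expected_shapes(
    h=800, w=1216, num_boxes=5, feature_sizes=[(50, 76)], num_anchors=15, mode_mask=True
)
```

The image and mask shapes share h and w, while the anchor tensors use the smaller feature-map resolution.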



def get_train_dataflow():
    """
    Return a training dataflow. Each datapoint consists of the following:

    An image: (h, w, 3),

    1 or more pairs of (anchor_labels, anchor_boxes):
    anchor_labels: (h', w', NA)
    anchor_boxes: (h', w', NA, 4)

    gt_boxes: (N, 4)
    gt_labels: (N,)

    If MODE_MASK, gt_masks: (N, h, w)
    """
    roidbs = list(itertools.chain.from_iterable(DatasetRegistry.get(x).training_roidbs() for x in cfg.DATA.TRAIN))
    print_class_histogram(roidbs)

    # Filter out images that have no gt boxes, but this filter shall not be applied for testing.
    # The model does support training with empty images, but it is not useful for COCO.
    num = len(roidbs)
    if cfg.DATA.FILTER_EMPTY_ANNOTATIONS:
        roidbs = list(filter(lambda img: len(img["boxes"][img["is_crowd"] == 0]) > 0, roidbs))
    logger.info(
        "Filtered {} images which contain no non-crowd groudtruth boxes. Total #images for training: {}".format(
            num - len(roidbs), len(roidbs)
        )
    )

    ds = DataFromList(roidbs, shuffle=True)

    preprocess = TrainingDataPreprocessor(cfg)

    if cfg.DATA.NUM_WORKERS > 0:
        if cfg.TRAINER == "horovod":
            buffer_size = cfg.DATA.NUM_WORKERS * 10  # one dataflow for each process, therefore don't need large buffer
            ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
            # MPI does not like fork()
        else:
            buffer_size = cfg.DATA.NUM_WORKERS * 20
            ds = MultiProcessMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
    else:
        ds = MapData(ds, preprocess)
    return ds
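The first two steps of get_train_dataflow, merging the per-dataset roidb lists and dropping images without non-crowd boxes, can be tried in isolation. The sketch below is a simplification under stated assumptions: the sample records, file names, and the helper `has_non_crowd_box` are hypothetical, and plain Python lists replace the NumPy arrays that the original boolean indexing `img["boxes"][img["is_crowd"] == 0]` relies on.

```python
import itertools
from typing import Dict, List

# Hypothetical stand-ins for the lists returned by training_roidbs().
# Real entries hold NumPy arrays; plain lists keep this sketch dependency-free.
dataset_a = [
    {"file_name": "a0.jpg", "boxes": [[0, 0, 10, 10]], "is_crowd": [0]},
    {"file_name": "a1.jpg", "boxes": [], "is_crowd": []},
]
dataset_b = [
    {"file_name": "b0.jpg", "boxes": [[5, 5, 20, 20]], "is_crowd": [1]},
    {"file_name": "b1.jpg", "boxes": [[1, 1, 4, 4], [2, 2, 8, 8]], "is_crowd": [1, 0]},
]


def has_non_crowd_box(img: Dict) -> bool:
    """True if at least one box of the image is not flagged as crowd."""
    return any(flag == 0 for flag in img["is_crowd"])


# Flatten the per-dataset lists into a single list of image records.
roidbs: List[Dict] = list(itertools.chain.from_iterable([dataset_a, dataset_b]))
num = len(roidbs)

# Keep only images that still have a box after removing crowd annotations.
kept = [img for img in roidbs if has_non_crowd_box(img)]
print("Filtered {} images; {} remain for training".format(num - len(kept), len(kept)))
```

An image with only crowd boxes (b0.jpg here) is dropped just like an image with no boxes at all, which matches the wording of the log message in the source.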

Callers 3

do_visualize (Function) · 0.90
train.py (File) · 0.90
data.py (File) · 0.85

Calls 8

DataFromList (Class) · 0.90
MultiThreadMapData (Class) · 0.90
MapData (Class) · 0.90
print_class_histogram (Function) · 0.85
format (Method) · 0.80
training_roidbs (Method) · 0.45
get (Method) · 0.45
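
Which of the map classes is actually constructed depends on two configuration values, cfg.DATA.NUM_WORKERS and cfg.TRAINER. The sketch below restates that branch as a pure function so it can be checked without tensorpack installed. The helper `choose_mapper` is hypothetical and returns class names as strings instead of building dataflows; the trainer name "replicated" in the usage lines stands in for any value other than "horovod".

```python
from typing import Optional, Tuple


def choose_mapper(num_workers: int, trainer: str) -> Tuple[str, Optional[int]]:
    """Return (map class name, buffer size) following the branch in get_train_dataflow.

    Hypothetical helper for illustration; it does not construct any dataflow.
    """
    if num_workers > 0:
        if trainer == "horovod":
            # Threads instead of processes: the source notes that MPI does not like fork().
            # One dataflow exists per process, so a smaller buffer is enough.
            return "MultiThreadMapData", num_workers * 10
        return "MultiProcessMapData", num_workers * 20
    # No workers: preprocessing runs inline and no buffer is configured.
    return "MapData", None


horovod_choice = choose_mapper(num_workers=8, trainer="horovod")
default_choice = choose_mapper(num_workers=8, trainer="replicated")
inline_choice = choose_mapper(num_workers=0, trainer="replicated")
```

With 8 workers, the Horovod branch buffers 80 datapoints and the multi-process branch buffers 160.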

Tested by

no test coverage detected