(self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
aspect_ratios: dict,
drop_last: bool = False,
config=None,
valid_num=0, # take as valid aspect-ratio when sample number >= valid_num
**kwargs)
| 20 | """ |
| 21 | |
def __init__(self,
             sampler: Sampler,
             dataset: Dataset,
             batch_size: int,
             aspect_ratios: dict,
             drop_last: bool = False,
             config=None,
             valid_num=0,  # take as valid aspect-ratio when sample number >= valid_num
             **kwargs) -> None:
    """Validate the arguments and set up one empty bucket per aspect ratio.

    Args:
        sampler: Base sampler; stored as ``self.sampler``.
        dataset: Stored as ``self.dataset``.
        batch_size: Positive integer batch size.
        aspect_ratios: Mapping whose keys name the aspect-ratio buckets.
        drop_last: Stored as ``self.drop_last``.
        config: Optional config object. When given, the log file
            ``<config.work_dir>/train_log.log`` is passed to the logger.
        valid_num: An aspect ratio is listed as available only when its
            count in ``ratio_nums`` is ``>= valid_num``.
        **kwargs: Must contain ``ratio_nums``, a non-empty mapping from
            aspect ratio to a number comparable with ``valid_num``
            (presumably the sample count per ratio -- confirm with caller).

    Raises:
        TypeError: If ``sampler`` is not a ``Sampler`` instance.
        ValueError: If ``batch_size`` is not a positive integer, or if
            ``ratio_nums`` is missing or empty.
    """
    if not isinstance(sampler, Sampler):
        raise TypeError('sampler should be an instance of ``Sampler``, '
                        f'but got {sampler}')
    if not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError('batch_size should be a positive integer value, '
                         f'but got batch_size={batch_size}')
    self.sampler = sampler
    self.dataset = dataset
    self.batch_size = batch_size
    self.aspect_ratios = aspect_ratios
    self.drop_last = drop_last
    self.ratio_nums_gt = kwargs.get('ratio_nums')
    self.config = config
    # Explicit check instead of ``assert``: asserts are removed under
    # ``python -O``, which would let a missing ``ratio_nums`` fail later
    # with an unrelated AttributeError on ``None.items()``.
    if not self.ratio_nums_gt:
        raise ValueError('ratio_nums should be a non-empty mapping passed as a '
                         f'keyword argument, but got ratio_nums={self.ratio_nums_gt!r}')
    # buckets for each aspect ratio
    self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
    # Keys are stringified here; ratios below the valid_num threshold are excluded.
    self.current_available_bucket_keys = [
        str(ratio) for ratio, count in self.ratio_nums_gt.items() if count >= valid_num
    ]
    if config is None:
        logger = get_root_logger()
    else:
        logger = get_root_logger(os.path.join(config.work_dir, 'train_log.log'))
    logger.warning(f"Using valid_num={valid_num} in config file. Available {len(self.current_available_bucket_keys)} aspect_ratios: {self.current_available_bucket_keys}")
| 50 | |
| 51 | def __iter__(self) -> Sequence[int]: |
| 52 | for idx in self.sampler: |
no test coverage detected