(self, config, mode, logger, seed=None)
| 171 | |
| 172 | class SimpleDataSet(Dataset): |
def __init__(self, config, mode, logger, seed=None):
    """Build the dataset for one config section (e.g. ``"Train"``).

    Args:
        config: Full config mapping. Reads ``config["Global"]``,
            ``config[mode]["dataset"]`` and ``config[mode]["loader"]``.
            The mapping is no longer mutated, so the same config can be
            used to construct the dataset again.
        mode: Name of the config section to read. It is stored lower-cased
            in ``self.mode``.
        logger: Logger used for progress messages.
        seed: Initial seed/epoch passed to the index-map generator.
            ``None`` is treated as epoch 0 for the epoch counters.

    Raises:
        KeyError: if a required config key is missing.
        AssertionError: if an explicit ``ratio_list`` does not have one
            entry per label file.
    """
    super(SimpleDataSet, self).__init__()
    self.logger = logger
    self.mode = mode.lower()

    global_config = config["Global"]
    dataset_config = config[mode]["dataset"]
    loader_config = config[mode]["loader"]

    self.delimiter = dataset_config.get("delimiter", "\t")
    # Read instead of pop(): popping mutated the caller's config, so building
    # a second dataset from the same config raised KeyError.
    label_file_list = dataset_config["label_file_list"]
    if isinstance(label_file_list, str):
        # A single path given as a bare string would otherwise be measured
        # by character count below; treat it as one label file.
        label_file_list = [label_file_list]
    data_source_num = len(label_file_list)
    ratio_list = dataset_config.get("ratio_list", 1.0)
    if isinstance(ratio_list, (float, int)):
        # A scalar ratio applies uniformly to every label file.
        ratio_list = [float(ratio_list)] * data_source_num
    self.label_file_list = label_file_list
    self.ratio_list = ratio_list

    assert (
        len(ratio_list) == data_source_num
    ), "The length of ratio_list should be the same as the file_list."
    self.data_dir = dataset_config["data_dir"]
    self.do_shuffle = loader_config["shuffle"]
    self.seed = seed
    # Any ratio below 1 selects the index-map path below, which re-samples
    # via _generate_index_map instead of a fixed line list.
    self.need_reset = any(x < 1 for x in ratio_list)
    start_epoch = seed if seed is not None else 0

    # Wrap in a 1-tuple so a tuple-valued label_file_list formats safely.
    logger.info("Initialize indexs of datasets:%s" % (label_file_list,))

    if self.need_reset:
        # Load all lines once; ratio sampling is expressed through
        # _index_map (virtual idx -> global idx) rather than re-reading files.
        self._all_lines, self.file_boundaries = self._load_all_lines(
            label_file_list
        )
        self._index_map = self._generate_index_map(seed)
        self._cached_epoch = start_epoch
        # Kept for API compatibility. In this mode these are noted as not
        # being used by __getitem__ (not visible here; confirm).
        self.data_lines = self._all_lines
        self.data_idx_order_list = list(range(len(self._index_map)))
    else:
        self._all_lines = None
        self._index_map = None
        self._cached_epoch = None
        self.file_boundaries = None
        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
        self.data_idx_order_list = list(range(len(self.data_lines)))
        if self.mode == "train" and self.do_shuffle:
            self.shuffle_data_random()

    # Process-shared C int holding the current epoch. This is presumably
    # read by DataLoader workers to detect epoch changes; the reader is not
    # visible in this view.
    self._shared_epoch = multiprocessing.Value("i", start_epoch)

    self.ops = create_operators(dataset_config["transforms"], global_config)
    self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
| 227 | |
| 228 | # ------------------------------------------------------------------ # |
| 229 | # Data loading helpers |
no test coverage detected