(self, config, mode, logger, seed=None)
| 20 | |
| 21 | class PGDataSet(Dataset): |
def __init__(self, config, mode, logger, seed=None):
    """Index the label files listed under ``config[mode]`` and build the transforms.

    Args:
        config: Nested configuration mapping. This method reads
            ``config["Global"]``, ``config[mode]["dataset"]`` (keys
            ``label_file_list``, ``data_dir``, ``transforms`` and the
            optional ``delimiter`` / ``ratio_list``) and
            ``config[mode]["loader"]["shuffle"]``.
        mode: Name of the config section to use. When it equals
            ``"train"`` (case-insensitive), ``shuffle_data_random`` is
            called once after the index is built.
        logger: Object exposing ``info``; used for a single progress message.
        seed: Stored as ``self.seed``; not otherwise used in this method.

    Raises:
        KeyError: If a required configuration key is missing.
        AssertionError: If an explicit ``ratio_list`` does not have one
            entry per label file.
    """
    super(PGDataSet, self).__init__()

    self.logger = logger
    self.seed = seed
    self.mode = mode
    global_config = config["Global"]
    dataset_config = config[mode]["dataset"]
    loader_config = config[mode]["loader"]

    self.delimiter = dataset_config.get("delimiter", "\t")
    # NOTE: ``pop`` removes the key from the caller's config mapping; kept
    # as-is so the observable side effect on ``config`` does not change.
    label_file_list = dataset_config.pop("label_file_list")
    if isinstance(label_file_list, str):
        # A single path given as a bare string would otherwise be counted
        # character by character by ``len`` below.
        label_file_list = [label_file_list]
    data_source_num = len(label_file_list)

    # Default to a scalar so that a missing ``ratio_list`` is broadcast to
    # every label file instead of failing the length check whenever more
    # than one file is configured.
    ratio_list = dataset_config.get("ratio_list", 1.0)
    if isinstance(ratio_list, (float, int)):
        ratio_list = [float(ratio_list)] * data_source_num
    assert (
        len(ratio_list) == data_source_num
    ), "The length of ratio_list should be the same as the file_list."
    self.data_dir = dataset_config["data_dir"]
    self.do_shuffle = loader_config["shuffle"]

    # Wrap in a 1-tuple so a tuple-valued label_file_list is formatted as a
    # single value rather than being unpacked as multiple format arguments.
    logger.info("Initialize indexes of datasets:%s" % (label_file_list,))
    self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
    self.data_idx_order_list = list(range(len(self.data_lines)))
    if mode.lower() == "train":
        self.shuffle_data_random()

    self.ops = create_operators(dataset_config["transforms"], global_config)

    # True when at least one source has a ratio below 1.
    self.need_reset = any(ratio < 1 for ratio in ratio_list)
| 53 | |
| 54 | def shuffle_data_random(self): |
| 55 | if self.do_shuffle: |
nothing calls this directly
no test coverage detected