(self, config, mode, logger, seed=None)
| 61 | |
| 62 | class LaTeXOCRDataSet(Dataset): |
def __init__(self, config, mode, logger, seed=None):
    """Load the bucketed LaTeX-OCR pickle and pre-group samples into batches.

    ``config[mode]["dataset"]["data"]`` names a pickle file that must
    deserialize to a dict. Each key ``k`` is compared via ``k[0]`` and
    ``k[1]`` against ``min_dimensions`` / ``max_dimensions`` (presumably
    image width/height buckets -- confirm against the data builder); buckets
    outside those inclusive bounds are dropped. The samples of every kept
    bucket are split into chunks of ``batch_size_per_pair`` rows and stored
    in ``self.pairs`` as a NumPy object array.

    Args:
        config: Full config mapping with a ``"Global"`` section and a
            ``mode`` section holding ``"dataset"`` and ``"loader"``.
        mode: Name of the config section to read (e.g. ``"Train"``); the
            lower-cased value is stored in ``self.mode``.
        logger: Logger stored on the instance.
        seed: Optional seed. When not ``None`` it drives a dedicated NumPy
            generator used for every shuffle below, so batch composition and
            order are reproducible. Python's ``random`` module is also seeded
            with it when shuffling in train mode (unchanged legacy side
            effect). When ``None`` the global NumPy generator is used.

    Raises:
        KeyError: If a required config key is missing.
        pickle.UnpicklingError: If the payload does not deserialize to a dict.
    """
    super(LaTeXOCRDataSet, self).__init__()
    self.logger = logger
    self.mode = mode.lower()

    global_config = config["Global"]
    dataset_config = config[mode]["dataset"]
    loader_config = config[mode]["loader"]

    # NOTE(review): ``pop`` removes these keys from the caller's config dicts,
    # so building a second dataset from the same dict raises KeyError. Kept
    # because later code may hand ``global_config`` / ``dataset_config`` on and
    # rely on these keys being stripped -- confirm before switching to reads.
    pkl_path = dataset_config.pop("data")
    self.data_dir = dataset_config["data_dir"]
    self.min_dimensions = dataset_config.pop("min_dimensions")
    self.max_dimensions = dataset_config.pop("max_dimensions")
    self.batchsize = dataset_config.pop("batch_size_per_pair")
    self.keep_smaller_batches = dataset_config.pop("keep_smaller_batches")
    self.max_seq_len = global_config.pop("max_seq_len")
    self.rec_char_dict_path = global_config.pop("rec_char_dict_path")
    self.tokenizer = LatexOCRLabelEncode(self.rec_char_dict_path)

    with open(pkl_path, "rb") as file:
        data = _restricted_pickle_load(file)
    if not isinstance(data, dict):
        raise pickle.UnpicklingError(
            "LaTeXOCR dataset payload must deserialize to a dict"
        )

    # Keep only buckets whose key lies inside the inclusive dimension bounds.
    lo, hi = self.min_dimensions, self.max_dimensions
    self.data = {
        k: v
        for k, v in data.items()
        if lo[0] <= k[0] <= hi[0] and lo[1] <= k[1] <= hi[1]
    }
    self.do_shuffle = loader_config["shuffle"]
    self.seed = seed

    # Samples inside a bucket are shuffled only for the training split.
    shuffle_buckets = self.mode == "train" and self.do_shuffle
    if shuffle_buckets:
        random.seed(self.seed)

    # One generator for every shuffle below. Seeding only ``random`` (as
    # before) had no effect on the paddle/NumPy shuffles. The NumPy seed is
    # derived through ``random.Random`` so any value ``random.seed`` accepts
    # (negative/large ints, str, ...) is valid here too.
    if self.seed is None:
        rng = np.random
    else:
        rng = np.random.RandomState(random.Random(self.seed).getrandbits(32))

    batch_size = self.batchsize
    pairs = []
    for samples in self.data.values():
        info = np.array(samples, dtype=object)
        count = len(info)
        # NumPy indices (not framework tensors): no device allocation, and an
        # empty bucket simply yields no batches.
        order = rng.permutation(count) if shuffle_buckets else np.arange(count)
        for start in range(0, count, batch_size):
            batch = info[order[start : start + batch_size]]
            # Guarantee a leading "rows" axis even for 1-D selections.
            if batch.ndim == 1:
                batch = batch[None, :]
            if len(batch) < batch_size and not self.keep_smaller_batches:
                continue
            pairs.append(batch)

    # NOTE(review): batch order is shuffled whenever ``shuffle`` is set, even
    # outside train mode, unlike the in-bucket shuffle above -- preserved
    # as-is; confirm the asymmetry is intended.
    if self.do_shuffle:
        self.pairs = rng.permutation(np.array(pairs, dtype=object))
    else:
        self.pairs = np.array(pairs, dtype=object)
| 120 |
nothing calls this directly
no test coverage detected