(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
preprocessed_path,
num_data,
)
| 267 | """Dataset for supervised fine-tuning.""" |
| 268 | |
| 269 | def __init__( |
| 270 | self, |
| 271 | data_path: str, |
| 272 | tokenizer: transformers.PreTrainedTokenizer, |
| 273 | preprocessed_path, |
| 274 | num_data, |
| 275 | ): |
| 276 | super(SupervisedDataset, self).__init__() |
| 277 | |
| 278 | # save to file |
| 279 | # Make sure only the first process is processing the dataset |
| 280 | if dist.get_rank() != 0: |
| 281 | dist.barrier() |
| 282 | self.preprocessed_path = preprocessed_path |
| 283 | if os.path.exists(self.preprocessed_path): |
| 284 | logging.warning("loading from preprocessed data") |
| 285 | with open(self.preprocessed_path, "r") as f: |
| 286 | data_dict = json.load(f) |
| 287 | if dist.get_rank() == 0: |
| 288 | dist.barrier() |
| 289 | else: |
| 290 | if not os.path.exists("preprocessed_data"): |
| 291 | os.mkdir("preprocessed_data") |
| 292 | assert dist.get_rank() == 0, "Only the first process should process" |
| 293 | logging.warning("Loading data...") |
| 294 | list_data_dict = json.load(open(data_path, "r")) |
| 295 | |
| 296 | logging.warning("Formatting inputs...") |
| 297 | sources = [] |
| 298 | |
| 299 | sources = [example["conversations"] for example in list_data_dict] |
| 300 | |
| 301 | data_dict = preprocess(sources, tokenizer) |
| 302 | json_data_dict = json.dumps(data_dict) |
| 303 | |
| 304 | # Remember to close file to avoid concurrent r/w |
| 305 | with open(self.preprocessed_path, "w") as f: |
| 306 | f.write(json_data_dict) |
| 307 | |
| 308 | # Release barrier |
| 309 | dist.barrier() |
| 310 | |
| 311 | if num_data != -1: |
| 312 | data_dict["input_ids"] = data_dict["input_ids"][:num_data] |
| 313 | data_dict["labels"] = data_dict["labels"][:num_data] |
| 314 | |
| 315 | # Shuffle data to see more conversations, if only train on partial data |
| 316 | temp = list(zip(data_dict["input_ids"], data_dict["labels"])) |
| 317 | random.shuffle(temp) |
| 318 | res1, res2 = zip(*temp) |
| 319 | data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2) |
| 320 | |
| 321 | # Dacheng: Get rid of short QA pair |
| 322 | self.input_ids = copy.deepcopy(data_dict["input_ids"]) |
| 323 | self.labels = copy.deepcopy(data_dict["labels"]) |
| 324 | length_arr = defaultdict(int) |
| 325 | for idx, (input, label) in enumerate( |
| 326 | zip(data_dict["input_ids"], data_dict["labels"]) |
nothing calls this directly
no test coverage detected