MCPcopy
hub / github.com/lm-sys/FastChat / __init__

Method __init__

fastchat/train/train_flant5.py:269–340  ·  view source on GitHub ↗
(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        preprocessed_path,
        num_data,
    )

Source from the content-addressed store, hash-verified

267 """Dataset for supervised fine-tuning."""
268
269 def __init__(
270 self,
271 data_path: str,
272 tokenizer: transformers.PreTrainedTokenizer,
273 preprocessed_path,
274 num_data,
275 ):
276 super(SupervisedDataset, self).__init__()
277
278 # save to file
279 # Make sure only the first process is processing the dataset
280 if dist.get_rank() != 0:
281 dist.barrier()
282 self.preprocessed_path = preprocessed_path
283 if os.path.exists(self.preprocessed_path):
284 logging.warning("loading from preprocessed data")
285 with open(self.preprocessed_path, "r") as f:
286 data_dict = json.load(f)
287 if dist.get_rank() == 0:
288 dist.barrier()
289 else:
290 if not os.path.exists("preprocessed_data"):
291 os.mkdir("preprocessed_data")
292 assert dist.get_rank() == 0, "Only the first process should process"
293 logging.warning("Loading data...")
294 list_data_dict = json.load(open(data_path, "r"))
295
296 logging.warning("Formatting inputs...")
297 sources = []
298
299 sources = [example["conversations"] for example in list_data_dict]
300
301 data_dict = preprocess(sources, tokenizer)
302 json_data_dict = json.dumps(data_dict)
303
304 # Remember to close file to avoid concurrent r/w
305 with open(self.preprocessed_path, "w") as f:
306 f.write(json_data_dict)
307
308 # Release barrier
309 dist.barrier()
310
311 if num_data != -1:
312 data_dict["input_ids"] = data_dict["input_ids"][:num_data]
313 data_dict["labels"] = data_dict["labels"][:num_data]
314
315 # Shuffle data to see more conversations, if only train on partial data
316 temp = list(zip(data_dict["input_ids"], data_dict["labels"]))
317 random.shuffle(temp)
318 res1, res2 = zip(*temp)
319 data_dict["input_ids"], data_dict["labels"] = list(res1), list(res2)
320
321 # Dacheng: Get rid of short QA pair
322 self.input_ids = copy.deepcopy(data_dict["input_ids"])
323 self.labels = copy.deepcopy(data_dict["labels"])
324 length_arr = defaultdict(int)
325 for idx, (input, label) in enumerate(
326 zip(data_dict["input_ids"], data_dict["labels"])

Callers

nothing calls this directly

Calls 2

write · Method · 0.80
preprocess · Function · 0.70

Tested by

no test coverage detected