| 92 | return train_dataset |
| 93 | |
def process_concat_data(text, tokenizer, max_seq_len, args, pad_token_id=3):
    """Tokenize a concatenated multi-turn dialogue and left-pad it.

    ``text`` is split into turns on ``"<_end>"``. Each non-blank turn must
    contain exactly one ``"<_bot>"``, which separates the user part from the
    bot part. An optional leading ``"<_user>"`` tag is stripped from the user
    part. Each turn is encoded as::

        [args.user_token_id] + ids(user) + [args.bot_token_id] + ids(bot) + [args.end_token_id]

    All turns are concatenated in order.

    Args:
        text: The raw dialogue string.
        tokenizer: Callable whose result supports ``["input_ids"]`` returning
            a list of ints (presumably a Hugging Face tokenizer -- not
            checked here).
        max_seq_len: Target length. Shorter sequences are left-padded up to
            it; longer sequences are returned unchanged (NOT truncated).
        args: Object providing ``user_token_id``, ``bot_token_id`` and
            ``end_token_id``.
        pad_token_id: Id used for left padding. Defaults to 3, the value
            that was previously hard-coded.

    Returns:
        dict with ``"input_ids"`` (1-D ``torch.long`` tensor) and
        ``"attention_mask"`` (``torch.ones`` of the same length).

    Raises:
        ValueError: If a non-blank turn does not contain exactly one
            ``"<_bot>"``.
    """
    user_tag = "<_user>"
    sentence_ids = []
    for turn in text.split("<_end>"):
        # Skip empty/whitespace-only segments. The remainder after the final
        # "<_end>" may be "" or a trailing newline; the latter used to crash.
        if not turn.strip():
            continue
        parts = turn.split("<_bot>")
        if len(parts) != 2:
            raise ValueError(
                f"expected exactly one '<_bot>' in dialogue turn, "
                f"found {len(parts) - 1}: {turn!r}"
            )
        prompt, reply = parts
        # Only a tag at the very start is removed, matching the previous
        # re.sub(r"^<_user>", "", ...) behaviour.
        if prompt.startswith(user_tag):
            prompt = prompt[len(user_tag):]
        sentence_ids += [args.user_token_id] + tokenizer(prompt)["input_ids"]
        sentence_ids += (
            [args.bot_token_id] + tokenizer(reply)["input_ids"] + [args.end_token_id]
        )
    # Left padding. A negative repeat count yields [], so over-long input is
    # kept as-is.
    sentence_ids = [pad_token_id] * (max_seq_len - len(sentence_ids)) + sentence_ids
    # NOTE(review): the mask is all ones, so padded positions are attended
    # to as well -- confirm the training loop intends this.
    return {
        "input_ids": torch.tensor(sentence_ids, dtype=torch.long),
        "attention_mask": torch.ones(len(sentence_ids)),
    }
| 106 | |
| 107 | |
| 108 | def process(id, samples, tokenizer, max_seq_len, num_workers, num_samples, output_path, args): |