(id, samples, tokenizer, max_seq_len, num_workers, num_samples, output_path, args)
| 106 | |
| 107 | |
def _worker_shard(samples, worker_id, num_workers):
    """Return the contiguous slice of ``samples`` owned by ``worker_id``.

    Every worker gets ``len(samples) // num_workers`` items; the last worker
    additionally takes the remainder.
    """
    size = len(samples) // num_workers
    start = worker_id * size
    if worker_id < num_workers - 1:
        return samples[start:start + size]
    return samples[start:]


def _packing_candidates(shard, tokenizer, limit, show_progress, chunk_size=1):
    """Measure token lengths of ``shard`` and keep the texts that can be packed.

    Returns ``(candidates, min_length)`` where ``candidates`` is a list of
    ``(text, token_length)`` pairs, in shard order, whose length is strictly
    below ``limit``, and ``min_length`` is the smallest length over the whole
    shard (including texts that were filtered out).
    """
    lengths = []
    # Measure the length of every text; the tokenizer is only used for counting.
    for start in tqdm(range(0, len(shard), chunk_size), disable=not show_progress):
        encoded = tokenizer.batch_encode_plus(shard[start:start + chunk_size], padding=False)
        lengths.extend(len(ids) for ids in encoded["input_ids"])
    # Only texts shorter than the packing limit can ever be placed in a buffer.
    candidates = [(text, length) for text, length in zip(shard, lengths) if length < limit]
    return candidates, min(lengths)


def _pack_pass(candidates, min_threshold, max_seq_len, limit, needed, pbar):
    """Run one greedy first-fit packing pass over ``candidates``.

    Texts are concatenated in order into buffers whose total token length stays
    strictly below ``limit``. A packed text is kept only if its ``<_user>``,
    ``<_bot>`` and ``<_end>`` counts match and it has at least one ``<_user>``.
    Stops once ``needed`` texts are kept or the candidates are exhausted.
    Returns the kept packed texts.
    """
    pool = list(candidates)  # copy: this pass consumes it, later passes reuse candidates
    kept = []
    while pool and len(kept) < needed:
        buffer, buffer_len, ptr = [], 0, 0
        # Stop filling once even the shortest text could no longer fit.
        while ptr < len(pool) and (max_seq_len - buffer_len) > min_threshold:
            text, length = pool[ptr]
            if buffer_len + length < limit:  # always leave the reserved padding free
                buffer.append(text)
                buffer_len += length
                pool.pop(ptr)
                pbar.update(1)
            else:
                ptr += 1
        if not buffer:
            # Nothing fits any more (only possible with unusual limits); avoid spinning forever.
            break
        packed = "".join(buffer)
        n_user = packed.count("<_user>")
        balanced = n_user == packed.count("<_bot>") == packed.count("<_end>")
        assert balanced, "unbalanced <_user>/<_bot>/<_end> markers in packed text"
        if balanced and n_user >= 1:
            kept.append(packed)
    return kept


def process(id, samples, tokenizer, max_seq_len, num_workers, num_samples, output_path, args,
            *, pad_reserve=10):
    """Pack this worker's share of ``samples`` into training sequences and save them.

    The worker takes a contiguous slice of ``samples``; the last worker also takes
    the remainder. It measures every text's token length with ``tokenizer`` and
    drops texts that are not shorter than ``max_seq_len - pad_reserve`` tokens.
    It then greedily concatenates the remaining texts (first-fit, original order)
    so each packed text stays below that limit. A packed text is kept when its
    ``<_user>``/``<_bot>``/``<_end>`` counts match and it contains at least one
    ``<_user>``. Packing passes over the shard are repeated until
    ``num_samples // num_workers`` texts are collected, so shard texts are reused
    when one pass is not enough.

    Each kept text is converted by ``process_concat_data``. The results are
    wrapped in ``PromptDataset`` and written with ``torch.save`` to
    ``<output_path>/train_data_<id>.pt``.

    Args:
        id: Index of this worker; selects the shard and the output file name.
            Only worker 0 shows progress bars.
        samples: Sliceable sequence of raw texts shared by all workers.
        tokenizer: Object exposing ``batch_encode_plus`` (presumably a Hugging
            Face tokenizer -- confirm); used only to measure token lengths.
        max_seq_len: Maximum packed length in tokens.
        num_workers: Total number of workers splitting ``samples``.
        num_samples: Total number of packed samples requested across all workers.
        output_path: Directory the ``.pt`` file is written to.
        args: Passed through unchanged to ``process_concat_data``.
        pad_reserve: Tokens left free below ``max_seq_len`` in every packed text
            (default 10, the previously hard-coded value).

    Returns:
        The list of ``process_concat_data`` results (the items that were saved).

    Raises:
        ValueError: If this worker's shard is empty, or no packable sample can be
            produced from it (which previously looped forever).
        AssertionError: If a packed text has unbalanced conversation markers.
    """
    train_fname = os.path.join(output_path, f"train_data_{id}.pt")
    target = num_samples // num_workers
    show_progress = id == 0
    limit = max_seq_len - pad_reserve
    all_lines = []

    if target > 0:
        shard = _worker_shard(samples, id, num_workers)
        if len(shard) == 0:
            raise ValueError(
                f"worker {id}: empty shard ({len(samples)} samples split over {num_workers} workers)"
            )
        # Tokenizing is the expensive step and does not change between passes: do it once.
        candidates, min_threshold = _packing_candidates(shard, tokenizer, limit, show_progress)
        while len(all_lines) < target:
            with tqdm(total=len(candidates), desc=f"Processing {id}, Concating dataset",
                      disable=not show_progress) as pbar:
                packed = _pack_pass(candidates, min_threshold, max_seq_len, limit,
                                    target - len(all_lines), pbar)
            if not packed:
                # Every pass is identical, so an empty pass means no progress is ever possible.
                raise ValueError(
                    f"worker {id}: no packed sample with at least one '<_user>' turn fits in "
                    f"max_seq_len={max_seq_len} (pad_reserve={pad_reserve})"
                )
            all_lines.extend(packed)

    dataset = [
        process_concat_data(line, tokenizer, max_seq_len, args)
        for line in tqdm(all_lines, desc="Convert token ids", disable=not show_progress)
    ]
    torch.save(PromptDataset(dataset), train_fname)
    return dataset
no test coverage detected