MCPcopy
hub / github.com/Tele-AI/Telechat / process

Function process

deepspeed-telechat/utils/data/data_utils.py:108–165  ·  view source on GitHub ↗
(id, samples, tokenizer, max_seq_len, num_workers, num_samples, output_path, args)

Source from the content-addressed store, hash-verified

106
107
def process(id, samples, tokenizer, max_seq_len, num_workers, num_samples, output_path, args):
    """Pack one worker's shard of ``samples`` into concatenated training sequences.

    Worker ``id`` takes its contiguous slice of ``samples`` (the last worker
    also takes the remainder), measures every sample with
    ``tokenizer.batch_encode_plus`` and greedily joins samples into strings
    whose summed token length stays below ``max_seq_len - 10``.  Passes over
    the shard are repeated until ``num_samples // num_workers`` strings have
    been collected.  Every string is then converted with
    ``process_concat_data``; the results are wrapped in ``PromptDataset`` and
    written with ``torch.save`` to ``<output_path>/train_data_<id>.pt``.

    Args:
        id: Index of this worker, used to select the shard and name the file.
        samples: Sliceable sequence of text samples shared by all workers.
        tokenizer: Object exposing ``batch_encode_plus`` returning a mapping
            with an ``"input_ids"`` entry.
        max_seq_len: Length budget of one packed sequence.
        num_workers: Total number of workers the samples are split across.
        num_samples: Total number of packed sequences wanted over all workers.
        output_path: Directory the ``.pt`` file is written to.
        args: Passed through unchanged to ``process_concat_data``.

    Returns:
        The list of values returned by ``process_concat_data``, one per
        packed string.

    Raises:
        ValueError: If the shard cannot yield any packed sequence (it is
            empty, every sample is too long, or no packed string contains a
            ``<_user>`` marker).  The previous implementation looped forever
            in the latter two situations.
        AssertionError: If a packed string has unequal numbers of
            ``<_user>``, ``<_bot>`` and ``<_end>`` markers.
    """
    target = num_samples // num_workers
    all_lines = []
    train_fname = os.path.join(output_path, f"train_data_{id}.pt")

    if target > 0:
        shard_size = len(samples) // num_workers
        start = id * shard_size
        # The last worker also picks up the remainder of the division.
        if id < num_workers - 1:
            shard = samples[start:start + shard_size]
        else:
            shard = samples[start:]

        # Measure the token length of every sample.  This does not depend on
        # the replication pass, so it is done once instead of once per pass.
        chunk_size = 1
        lengths = []
        for offset in tqdm(range(0, len(shard), chunk_size)):
            encoded = tokenizer.batch_encode_plus(shard[offset:offset + chunk_size], padding=False)
            lengths.extend(len(ids) for ids in encoded["input_ids"])

        # Only samples shorter than the budget (minus 10 reserved padding
        # positions) can be packed at all.
        limit = max_seq_len - 10
        candidates = [
            (text, length)
            for text, length in tqdm(zip(shard, lengths))
            if length < limit
        ]
        if not candidates:
            raise ValueError(
                f"worker {id}: no sample in the shard is shorter than "
                f"max_seq_len - 10 ({limit}) tokens"
            )
        min_threshold = min(lengths)

        while len(all_lines) < target:
            produced_before = len(all_lines)
            # Each pass consumes its own copy so the shard can be replayed.
            pool = list(candidates)
            with tqdm(total=len(pool), desc=f"Processing {id}, Concating dataset",
                      disable=(id != 0)) as pbar:
                while pool:
                    ptr = 0
                    buffer_len = 0
                    buffer = []
                    # Stop scanning once not even the shortest sample could fit.
                    while ptr < len(pool) and (max_seq_len - buffer_len) > min_threshold:
                        text, length = pool[ptr]
                        if length + buffer_len < limit:  # keep at least 10 padding slots
                            buffer_len += length
                            buffer.append(text)
                            pool.pop(ptr)
                            pbar.update(1)
                        else:
                            ptr += 1
                    output = "".join(buffer)
                    user_count = output.count("<_user>")
                    balanced = user_count == output.count("<_bot>") == output.count("<_end>")
                    assert balanced, "unbalanced <_user>/<_bot>/<_end> markers in packed sequence"
                    # The balance test is repeated so it still filters when
                    # assertions are disabled (python -O).
                    if balanced and user_count >= 1:
                        all_lines.append(output)
                    if len(all_lines) >= target:
                        break
            if len(all_lines) == produced_before:
                # A full pass added nothing; further passes would be identical.
                raise ValueError(
                    f"worker {id}: no packed sequence contains a <_user> marker, "
                    f"cannot reach {target} samples"
                )

    dataset = [
        process_concat_data(line, tokenizer, max_seq_len, args)
        for line in tqdm(all_lines, desc="Convert token ids", disable=(id != 0))
    ]
    train_dataset = PromptDataset(dataset)
    torch.save(train_dataset, train_fname)
    return dataset

Callers 1

create_prompt_dataset — Function · 0.85

Calls 5

process_concat_data — Function · 0.85
PromptDataset — Class · 0.85
append — Method · 0.45
pop — Method · 0.45
update — Method · 0.45

Tested by

no test coverage detected