MCPcopy
hub / github.com/Tele-AI/Telechat / get_weight_data

Function get_weight_data

deepspeed-telechat/utils/data/data_utils.py:55–86  ·  view source on GitHub ↗
(current_dataset, dataset_weight)

Source from the content-addressed store, hash-verified

53 }
54
def _strip_leading_tag(text, tag):
    """Return *text* with a single leading occurrence of *tag* removed, if present."""
    return text[len(tag):] if text.startswith(tag) else text


def _format_sample(prompt, response):
    """Render one prompt/response pair as a ``<_user>...<_bot>...<_end>`` string.

    One leading ``<_user>`` tag on *prompt* and one leading ``<_bot>`` tag on
    *response* are dropped before the tags are re-added, so they are never
    doubled. If the (re-tagged) prompt contains ``<_bot>``, it is treated as a
    multi-turn conversation: each ``<_user>`` segment that already holds a
    ``<_bot>`` reply is closed with ``<_end>``, every other segment is followed
    by ``<_bot>``, and *response* plus a final ``<_end>`` is appended.

    Raises:
        ValueError: if the rendered string does not contain equal numbers of
            ``<_user>``, ``<_bot>`` and ``<_end>`` tags.
    """
    prompt = "<_user>" + _strip_leading_tag(prompt, "<_user>")
    response = _strip_leading_tag(response, "<_bot>")

    if "<_bot>" in prompt:  # multi-turn
        pieces = []
        # Index 0 is the empty string before the leading "<_user>" tag.
        for turn in prompt.split("<_user>")[1:]:
            closer = "<_end>" if "<_bot>" in turn else "<_bot>"
            pieces.append("<_user>" + turn + closer)
        pieces.append(response + "<_end>")
        line = "".join(pieces)
    else:  # single turn
        line = prompt + "<_bot>" + response + "<_end>"

    # Explicit check instead of `assert`, so validation survives `python -O`.
    n_user = line.count("<_user>")
    n_bot = line.count("<_bot>")
    n_end = line.count("<_end>")
    if not n_user == n_bot == n_end:
        raise ValueError(
            f"unbalanced dialogue tags ({n_user} <_user>, {n_bot} <_bot>, "
            f"{n_end} <_end>) in formatted sample: {line[:200]!r}"
        )
    return line


def get_weight_data(current_dataset, dataset_weight, *, rng=None):
    """Format a dataset into training strings, resampled by ``dataset_weight``.

    Args:
        current_dataset: Iterable of records; each record is indexed with
            ``['input']`` and ``['output']``, both expected to be strings.
        dataset_weight: Sampling weight. Below 1.0, each record is kept
            independently, and it is dropped when a draw exceeds the weight.
            At 1.0 or above, each record is emitted ``floor(dataset_weight)``
            times, plus one extra copy when a draw falls below the fractional
            part.
        rng: Optional object with a ``random()`` method returning floats in
            [0, 1). Defaults to the global ``random`` module, which matches
            the original behaviour. Pass e.g. ``random.Random(seed)`` for
            isolated, reproducible sampling.

    Returns:
        list[str]: Formatted lines in dataset order, with duplicates adjacent.

    Raises:
        ValueError: If a record formats to a string with unbalanced tags.
    """
    rng = random if rng is None else rng
    downsample = dataset_weight < 1.0
    if not downsample:
        # Loop-invariant: split the weight into whole copies + a fractional chance.
        whole_copies = math.floor(dataset_weight)
        extra_copy_chance = dataset_weight - whole_copies

    all_lines = []
    for record in current_dataset:
        # Draw before formatting, so dropped records are never formatted.
        if downsample and rng.random() > dataset_weight:
            continue
        line = _format_sample(record['input'], record['output'])
        if downsample:
            all_lines.append(line)
            continue
        all_lines.extend([line] * whole_copies)
        if rng.random() < extra_copy_chance:
            all_lines.append(line)
    return all_lines
87
88def create_dataset( dataset_name, dataset_weight, output_path, seed):
89 raw_dataset = get_raw_dataset(dataset_name, output_path, seed)

Callers 1

create_dataset · Function · 0.85

Calls 1

append · Method · 0.45

Tested by

no test coverage detected