(current_dataset, dataset_weight)
| 53 | } |
| 54 | |
def get_weight_data(current_dataset, dataset_weight):
    """Render records as tagged dialogue lines, resampled by ``dataset_weight``.

    Each record is a mapping with string ``'input'`` and ``'output'`` entries.
    A leading ``<_user>`` on the input and a leading ``<_bot>`` on the output
    are optional; they are stripped and the input is re-prefixed so both
    spellings yield the same line.

    * Single-turn (no ``<_bot>`` in the input)::

          <_user>{input}<_bot>{output}<_end>

    * Multi-turn (the input already holds earlier ``<_bot>`` replies): every
      ``<_user>`` segment that contains a reply is closed with ``<_end>``,
      the segment without one is followed by ``<_bot>``, and the record's
      output plus ``<_end>`` is appended last.

    Sampling:

    * ``dataset_weight < 1.0``: a record is skipped when a uniform draw
      exceeds the weight, otherwise it is emitted once.
    * ``dataset_weight >= 1.0``: a record is emitted
      ``floor(dataset_weight)`` times, plus one extra copy when a uniform
      draw is below the fractional part of the weight.

    Args:
        current_dataset: Iterable of records as described above.
        dataset_weight: Sampling weight (number).

    Returns:
        list[str]: The formatted lines, in input order, with repeated copies
        of a record adjacent to each other.

    Raises:
        AssertionError: If a formatted line does not contain equal numbers of
            ``<_user>``, ``<_bot>`` and ``<_end>`` markers.
    """
    all_lines = []
    for record in current_dataset:
        # Down-sampling. The draw is only made for weights below 1.0
        # (short-circuit), which keeps seeded runs reproducible.
        if dataset_weight < 1.0 and random.random() > dataset_weight:
            continue

        # Normalise the optional role prefixes.
        prompt = "<_user>" + re.sub(r"^<_user>", "", record['input'], flags=re.S)
        reply = re.sub(r"^<_bot>", "", record['output'], flags=re.S)

        if "<_bot>" in prompt:  # multi-turn
            pieces = []
            # Element 0 of the split is the (empty) text before the first
            # "<_user>", hence the [1:].
            for turn in prompt.split("<_user>")[1:]:
                # A turn that already carries its reply is complete; the
                # open turn gets the bot marker that `reply` will follow.
                closing = "<_end>" if "<_bot>" in turn else "<_bot>"
                pieces.append("<_user>" + turn + closing)
            concat_line = "".join(pieces) + reply + "<_end>"
        else:  # single turn
            concat_line = prompt + "<_bot>" + reply + "<_end>"

        # Explicit raise instead of a bare `assert` so the sanity check is
        # not stripped when Python runs with -O; the exception type is the
        # one an `assert` would raise.
        user_count = concat_line.count("<_user>")
        bot_count = concat_line.count("<_bot>")
        end_count = concat_line.count("<_end>")
        if not (user_count == bot_count == end_count):
            raise AssertionError(
                "unbalanced dialogue markers "
                "(<_user>=%d, <_bot>=%d, <_end>=%d): %r"
                % (user_count, bot_count, end_count, concat_line)
            )

        if dataset_weight < 1.0:
            all_lines.append(concat_line)
        else:
            whole_copies = math.floor(dataset_weight)
            extra_chance = dataset_weight - whole_copies
            all_lines.extend([concat_line] * whole_copies)
            # The draw is made even when extra_chance is 0 so the sequence
            # of random numbers consumed per record stays constant.
            if random.random() < extra_chance:
                all_lines.append(concat_line)
    return all_lines
| 87 | |
| 88 | def create_dataset( dataset_name, dataset_weight, output_path, seed): |
| 89 | raw_dataset = get_raw_dataset(dataset_name, output_path, seed) |
no test coverage detected