| 66 | |
| 67 | |
| 68 | def sample_random_requests( |
| 69 | input_len: int, |
| 70 | output_len: int, |
| 71 | num_prompts: int, |
| 72 | range_ratio: float, |
| 73 | tokenizer: PreTrainedTokenizerBase, |
| 74 | dataset_path: str, |
| 75 | ) -> list[tuple[str, int, int]]: |
| 76 | |
| 77 | input_lens = np.random.randint( |
| 78 | max(int(input_len * range_ratio), 1), |
| 79 | input_len + 1, |
| 80 | size=num_prompts, |
| 81 | ) |
| 82 | output_lens = np.random.randint( |
| 83 | int(output_len * range_ratio), |
| 84 | output_len + 1, |
| 85 | size=num_prompts, |
| 86 | ) |
| 87 | |
| 88 | if True: |
| 89 | # Sample token ids from ShareGPT and repeat/truncate them to |
| 90 | # satisfy the input_lens |
| 91 | |
| 92 | # Load the dataset. |
| 93 | with open(dataset_path) as f: |
| 94 | dataset = json.load(f) |
| 95 | # Filter out the conversations with less than 2 turns. |
| 96 | dataset = [data for data in dataset if len(data['conversations']) >= 2] |
| 97 | # Only keep the first two turns of each conversation. |
| 98 | dataset = [(data['conversations'][0]['value'], data['conversations'][1]['value']) for data in dataset] |
| 99 | # remove the empty prompt |
| 100 | dataset = [(query, answer) for query, answer in dataset if len(query) > 0] |
| 101 | |
| 102 | # Shuffle the dataset. |
| 103 | random.shuffle(dataset) |
| 104 | |
| 105 | # Filter out sequences that are too long or too short |
| 106 | input_requests: list[tuple[str, int, int]] = [] |
| 107 | for i in range(num_prompts): |
| 108 | # Tokenize the prompts and completions. |
| 109 | prompt = dataset[i][0] |
| 110 | prompt_token_ids = tokenizer.encode(prompt) |
| 111 | prompt_len = len(prompt_token_ids) |
| 112 | |
| 113 | if prompt_len > input_lens[i]: |
| 114 | input_ids = prompt_token_ids[:input_lens[i]] |
| 115 | else: |
| 116 | ratio = (input_lens[i] + prompt_len - 1) // prompt_len |
| 117 | input_ids = (prompt_token_ids * ratio)[:input_lens[i]] |
| 118 | prompt = tokenizer.decode(input_ids) |
| 119 | input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]))) |
| 120 | else: |
| 121 | # Sample token ids from random integers. |
| 122 | # This can cause some NaN issues. |
| 123 | offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) |
| 124 | input_requests = [] |
| 125 | for i in range(num_prompts): |