(self, needles_info, random_cities_retrieve, context, context_length, seed)
| 141 | return context |
| 142 | |
def create_contexts(self, needles_info, random_cities_retrieve, context, context_length, seed):
    """Insert one needle per city into ``context`` and build the retrieval question.

    Args:
        needles_info: Mapping of city name -> ``(needle_rnd_number, depth_percent)``.
            For each entry, ``self.needle`` is formatted with ``city`` and
            ``rnd_number`` and passed to ``self.generate_context`` at that
            ``depth_percent``; entries are processed in the mapping's iteration order.
        random_cities_retrieve: Sequence of city names (strings) the question asks
            about. Each must be a key of ``needles_info``.
        context: Initial text handed to ``self.generate_context``; each call's
            result becomes the context for the next needle.
        context_length: Forwarded to ``self.generate_context`` and stored in the
            result converted with ``int()``.
        seed: Stored in the result unchanged.

    Returns:
        dict with keys ``'context'``, ``'context_length'``, ``'needles_info'``,
        ``'question'``, ``'cities_to_retrieve'`` and ``'seed'``.

    Raises:
        ValueError: if ``random_cities_retrieve`` is empty, or names a city that
            has no entry in ``needles_info``.
    """
    # Validate up front, before any context generation. An explicit raise is used
    # instead of ``assert`` because assertions are stripped under ``python -O``.
    if not random_cities_retrieve:
        raise ValueError("random_cities_retrieve must name at least one city")
    missing = [city for city in random_cities_retrieve if city not in needles_info]
    if missing:
        raise ValueError(f"cities to retrieve have no entry in needles_info: {missing}")

    # Each insertion builds on the context produced by the previous one.
    for city, (needle_rnd_number, depth_percent) in needles_info.items():
        context = self.generate_context(
            self.needle.format(city=city, rnd_number=needle_rnd_number),
            context, context_length, depth_percent,
        )

    if len(random_cities_retrieve) == 1:
        question = f"What is the special magic number for {random_cities_retrieve[0]}?"
    else:
        # NOTE: wording is kept byte-identical to the original, including the comma
        # before "and" even for two cities, so generated prompts stay reproducible.
        q = ', '.join(random_cities_retrieve[:-1]) + ', and ' + random_cities_retrieve[-1]
        question = self.retrieval_question.format(q)

    return {
        'context': context,
        'context_length': int(context_length),
        'needles_info': needles_info,
        'question': question,
        'cities_to_retrieve': random_cities_retrieve,
        'seed': seed,
    }
| 165 | |
| 166 | def insert_needle(self, needle, context, depth_percent, context_length): |
| 167 | tokens_needle = self.enc_tiktoken.encode(needle) |
no test coverage detected