(self)
| 212 | return int(context_length) |
| 213 | |
| 214 | def run_test(self): |
| 215 | fs = gcsfs.GCSFileSystem() |
| 216 | contexts = [] |
| 217 | template = self.OURS_TEMPLATE |
| 218 | |
def _key_from_result(result):
    """Build the tuple that identifies an already-saved result entry.

    The key is ``(context_length, depth_percent, seed)``, read from the
    result dict in that order; a missing field raises ``KeyError``.

    NOTE(review): the resume check in ``run_test`` looks up a 2-tuple
    ``(int(context_length), i)``, which can never equal this 3-tuple --
    confirm which key shape is intended.
    """
    key_fields = ('context_length', 'depth_percent', 'seed')
    return tuple(result[field] for field in key_fields)
| 221 | |
| 222 | results = [] |
| 223 | completed = set() |
def exists(fname):
    """Return whether ``fname`` exists.

    Paths starting with ``'gs://'`` are checked through ``fs.exists`` (the
    ``fs`` filesystem object from the enclosing scope); every other path is
    checked with ``os.path.exists``.
    """
    # Guard clause: anything without the gs:// prefix is a local path.
    if not fname.startswith('gs://'):
        return os.path.exists(fname)
    return fs.exists(fname)
| 229 | if exists(FLAGS.output_file): |
| 230 | with open_file(FLAGS.output_file, 'r') as f: |
| 231 | results = json.load(f) |
| 232 | completed = set([_key_from_result(result) for result in results]) |
| 233 | print('completed', len(completed)) |
| 234 | |
| 235 | full_contexts = self.read_context_files(FLAGS.n_rounds) |
| 236 | full_tokens = [self.enc.encode(full_context) for full_context in full_contexts] |
| 237 | |
| 238 | start = time.time() |
| 239 | for context_length in self.context_lengths: |
| 240 | trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens] |
| 241 | max_input_length = self.compute_max_input_length(context_length) |
| 242 | contexts = [] |
| 243 | for i in range(FLAGS.n_rounds): |
| 244 | if (int(context_length), i) in completed: |
| 245 | continue |
| 246 | random_cities = random.sample(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES, FLAGS.n_needles_total) |
| 247 | document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total) |
| 248 | random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve) |
| 249 | needles_info = {} |
| 250 | for random_city, depth_percent in zip(random_cities, document_depths): |
| 251 | needles_info[random_city] = ( |
| 252 | str(self.generate_random_number(self.rnd_number_digits)), |
| 253 | depth_percent |
| 254 | ) |
| 255 | context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i) |
| 256 | contexts.append(context) |
| 257 | |
| 258 | if len(contexts) == 0: |
| 259 | continue |
| 260 | |
| 261 | B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size) |
| 262 | B = int(B / self.model.data_dim) * self.model.data_dim |
| 263 | if B < self.model.data_dim: |
| 264 | B = self.model.data_dim |
| 265 | elif B > len(contexts): |
| 266 | B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim) |
| 267 | n_pad = B - len(contexts) % B |
| 268 | for _ in range(n_pad): |
| 269 | contexts.insert(0, contexts[0]) |
| 270 | |
| 271 | pbar = tqdm(total=len(contexts)) |
no test coverage detected