(self)
| 207 | return int(context_length) |
| 208 | |
| 209 | def run_test(self): |
| 210 | fs = gcsfs.GCSFileSystem() |
| 211 | contexts = [] |
| 212 | template = self.OURS_TEMPLATE |
| 213 | |
| 214 | def _key_from_result(result): |
| 215 | return (result['context_length'], result['depth_percent'], result['seed']) |
| 216 | |
| 217 | results = [] |
| 218 | completed = set() |
def exists(fname):
    """Report whether ``fname`` exists.

    Paths beginning with ``gs://`` are checked through the ``fs`` filesystem
    object from the enclosing scope; every other path is checked on the
    local filesystem with ``os.path.exists``.
    """
    is_gcs_path = fname.startswith('gs://')
    return fs.exists(fname) if is_gcs_path else os.path.exists(fname)
| 224 | if exists(FLAGS.output_file): |
| 225 | with open_file(FLAGS.output_file, 'r') as f: |
| 226 | results = json.load(f) |
| 227 | completed = set([_key_from_result(result) for result in results]) |
| 228 | print('completed', len(completed)) |
| 229 | |
| 230 | full_contexts = self.read_context_files(FLAGS.n_rounds) |
| 231 | full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)] |
| 232 | |
| 233 | start = time.time() |
| 234 | for context_length in self.context_lengths: |
| 235 | trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)] |
| 236 | max_input_length = self.compute_max_input_length(context_length) |
| 237 | contexts = [] |
| 238 | for depth_percent in self.document_depth_percents: |
| 239 | for i in range(FLAGS.n_rounds): |
| 240 | if (int(context_length), float(depth_percent), i) in completed: |
| 241 | continue |
| 242 | random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES) |
| 243 | insert_needle = True |
| 244 | needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits)) |
| 245 | print("context length: " + str(context_length)) |
| 246 | print("depth_percent : " + str(depth_percent)) |
| 247 | context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i) |
| 248 | contexts.append(context) |
| 249 | |
| 250 | if len(contexts) == 0: |
| 251 | continue |
| 252 | |
| 253 | B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size) |
| 254 | B = int(B / self.model.data_dim) * self.model.data_dim |
| 255 | if B < self.model.data_dim: |
| 256 | B = self.model.data_dim |
| 257 | elif B > len(contexts): |
| 258 | B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim) |
| 259 | if len(contexts) % B == 0: |
| 260 | n_pad = 0 |
| 261 | else: |
| 262 | n_pad = B - len(contexts) % B |
| 263 | for _ in range(n_pad): |
| 264 | contexts.insert(0, contexts[0]) |
| 265 | |
| 266 | pbar = tqdm(total=len(contexts)) |
no test coverage detected