MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / run_test

Method run_test

scripts/eval_needle.py:209–292  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

207 return int(context_length)
208
209 def run_test(self):
210 fs = gcsfs.GCSFileSystem()
211 contexts = []
212 template = self.OURS_TEMPLATE
213
214 def _key_from_result(result):
215 return (result['context_length'], result['depth_percent'], result['seed'])
216
217 results = []
218 completed = set()
219 def exists(fname):
220 if fname.startswith('gs://'):
221 return fs.exists(fname)
222 else:
223 return os.path.exists(fname)
224 if exists(FLAGS.output_file):
225 with open_file(FLAGS.output_file, 'r') as f:
226 results = json.load(f)
227 completed = set([_key_from_result(result) for result in results])
228 print('completed', len(completed))
229
230 full_contexts = self.read_context_files(FLAGS.n_rounds)
231 full_tokens = [self.enc.encode(full_context) for full_context in tqdm(full_contexts)]
232
233 start = time.time()
234 for context_length in self.context_lengths:
235 trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in tqdm(full_tokens)]
236 max_input_length = self.compute_max_input_length(context_length)
237 contexts = []
238 for depth_percent in self.document_depth_percents:
239 for i in range(FLAGS.n_rounds):
240 if (int(context_length), float(depth_percent), i) in completed:
241 continue
242 random_city = random.choice(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES)
243 insert_needle = True
244 needle_rnd_number = str(self.generate_random_number(self.rnd_number_digits))
245 print("context length: " + str(context_length))
246 print("depth_percent : " + str(depth_percent))
247 context = self.create_contexts(needle_rnd_number, insert_needle, random_city, trim_contexts[i], context_length, depth_percent, i)
248 contexts.append(context)
249
250 if len(contexts) == 0:
251 continue
252
253 B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
254 B = int(B / self.model.data_dim) * self.model.data_dim
255 if B < self.model.data_dim:
256 B = self.model.data_dim
257 elif B > len(contexts):
258 B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
259 if len(contexts) % B == 0:
260 n_pad = 0
261 else:
262 n_pad = B - len(contexts) % B
263 for _ in range(n_pad):
264 contexts.insert(0, contexts[0])
265
266 pbar = tqdm(total=len(contexts))

Callers 1

start_test — Method · 0.95

Calls 6

read_context_files — Method · 0.95
create_contexts — Method · 0.95
encode — Method · 0.45
decode — Method · 0.45

Tested by

no test coverage detected