MCPcopy Index your code
hub / github.com/LargeWorldModel/LWM / run_test

Method run_test

scripts/eval_needle_multi.py:214–301  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

212 return int(context_length)
213
214 def run_test(self):
215 fs = gcsfs.GCSFileSystem()
216 contexts = []
217 template = self.OURS_TEMPLATE
218
219 def _key_from_result(result):
220 return (result['context_length'], result['depth_percent'], result['seed'])
221
222 results = []
223 completed = set()
224 def exists(fname):
225 if fname.startswith('gs://'):
226 return fs.exists(fname)
227 else:
228 return os.path.exists(fname)
229 if exists(FLAGS.output_file):
230 with open_file(FLAGS.output_file, 'r') as f:
231 results = json.load(f)
232 completed = set([_key_from_result(result) for result in results])
233 print('completed', len(completed))
234
235 full_contexts = self.read_context_files(FLAGS.n_rounds)
236 full_tokens = [self.enc.encode(full_context) for full_context in full_contexts]
237
238 start = time.time()
239 for context_length in self.context_lengths:
240 trim_contexts = [self.enc.decode(full_token[:context_length]) for full_token in full_tokens]
241 max_input_length = self.compute_max_input_length(context_length)
242 contexts = []
243 for i in range(FLAGS.n_rounds):
244 if (int(context_length), i) in completed:
245 continue
246 random_cities = random.sample(LLMNeedleHaystackTester.RANDOM_NEEDLE_CITIES, FLAGS.n_needles_total)
247 document_depths = random.sample(self.document_depth_percents, FLAGS.n_needles_total)
248 random_cities_retrieve = random.sample(random_cities, FLAGS.n_needles_retrieve)
249 needles_info = {}
250 for random_city, depth_percent in zip(random_cities, document_depths):
251 needles_info[random_city] = (
252 str(self.generate_random_number(self.rnd_number_digits)),
253 depth_percent
254 )
255 context = self.create_contexts(needles_info, random_cities_retrieve, trim_contexts[i], context_length, i)
256 contexts.append(context)
257
258 if len(contexts) == 0:
259 continue
260
261 B = FLAGS.max_tokens_per_batch / (max_input_length + self.model.block_size)
262 B = int(B / self.model.data_dim) * self.model.data_dim
263 if B < self.model.data_dim:
264 B = self.model.data_dim
265 elif B > len(contexts):
266 B = int(math.ceil(len(contexts) / self.model.data_dim) * self.model.data_dim)
267 n_pad = B - len(contexts) % B
268 for _ in range(n_pad):
269 contexts.insert(0, contexts[0])
270
271 pbar = tqdm(total=len(contexts))

Callers 1

start_test · Method · 0.95

Calls 6

read_context_files · Method · 0.95
create_contexts · Method · 0.95
encode · Method · 0.45
decode · Method · 0.45

Tested by

no test coverage detected