MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX2 / server

Function server

evaluation/generation.py:315–391  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

313
314
315def server(args):
316 logger.info(f"[ server ] starting ...")
317 entries = read_dataset(args.data_path, dataset_type=args.dataset_type)
318
319 assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by batch_size"
320
321 for entry in entries.values():
322 entry["prompt"] = process_extra_prompt(
323 entry["prompt"],
324 language_type=args.language_type,
325 dataset_type=args.dataset_type,
326 generation_mode=args.generation_mode,
327 )
328
329 res = []
330 for entry in entries.values():
331 res.extend([entry] * (args.samples_per_problem // args.micro_batch_size))
332 random.shuffle(res)
333 all_entries = res
334
335 # setup zeromq channel
336 logger.info(f"[ server ] starting up on port {args.channel_port}")
337 context = zmq.Context()
338 logger.info(f"[ server ] creating socket")
339 socket = context.socket(zmq.REP)
340 logger.info(f"[ server ] binding to port {args.channel_port}")
341 socket.bind(f"tcp://*:{args.channel_port}")
342
343 logger.info(
344 f"[ server ] loaded {len(entries)} entries, generating {len(entries) * args.samples_per_problem} samples",
345 )
346
347 remaining_entries = all_entries.copy()
348 running_workers = args.gen_node_world_size * torch.cuda.device_count()
349 num_finished = 0
350
351 logger.info(f"[ server ] listening for requests ...")
352 start_time = time.perf_counter()
353 while True:
354 # Wait for next request from client
355 msg = socket.recv_json()
356 rank = msg["rank"]
357 action = msg["action"]
358
359 if action == "pull":
360 if len(remaining_entries) == 0:
361 socket.send_json({"task_id": None})
362 running_workers -= 1
363 logger.info(f"[ server ] Shutting down worker {rank}, remaining {running_workers} workers")
364 if running_workers == 0 and num_finished == len(all_entries):
365 logger.info(f"[ server ] All workers finished")
366 break
367 else:
368 entry = remaining_entries.pop()
369 time_elapsed = time.perf_counter() - start_time
370 logger.info(f"[ server ] Sending entry {entry['task_id']} to worker {rank}")
371 remaining = (
372 len(remaining_entries)

Callers

nothing calls this directly

Calls 3

read_dataset — Function · 0.90
process_extra_prompt — Function · 0.90
info — Method · 0.80

Tested by

no test coverage detected