(args)
| 313 | |
| 314 | |
| 315 | def server(args): |
| 316 | logger.info(f"[ server ] starting ...") |
| 317 | entries = read_dataset(args.data_path, dataset_type=args.dataset_type) |
| 318 | |
| 319 | assert args.samples_per_problem % args.micro_batch_size == 0, "samples_per_problem should be divisible by batch_size" |
| 320 | |
| 321 | for entry in entries.values(): |
| 322 | entry["prompt"] = process_extra_prompt( |
| 323 | entry["prompt"], |
| 324 | language_type=args.language_type, |
| 325 | dataset_type=args.dataset_type, |
| 326 | generation_mode=args.generation_mode, |
| 327 | ) |
| 328 | |
| 329 | res = [] |
| 330 | for entry in entries.values(): |
| 331 | res.extend([entry] * (args.samples_per_problem // args.micro_batch_size)) |
| 332 | random.shuffle(res) |
| 333 | all_entries = res |
| 334 | |
| 335 | # setup zeromq channel |
| 336 | logger.info(f"[ server ] starting up on port {args.channel_port}") |
| 337 | context = zmq.Context() |
| 338 | logger.info(f"[ server ] creating socket") |
| 339 | socket = context.socket(zmq.REP) |
| 340 | logger.info(f"[ server ] binding to port {args.channel_port}") |
| 341 | socket.bind(f"tcp://*:{args.channel_port}") |
| 342 | |
| 343 | logger.info( |
| 344 | f"[ server ] loaded {len(entries)} entries, generating {len(entries) * args.samples_per_problem} samples", |
| 345 | ) |
| 346 | |
| 347 | remaining_entries = all_entries.copy() |
| 348 | running_workers = args.gen_node_world_size * torch.cuda.device_count() |
| 349 | num_finished = 0 |
| 350 | |
| 351 | logger.info(f"[ server ] listening for requests ...") |
| 352 | start_time = time.perf_counter() |
| 353 | while True: |
| 354 | # Wait for next request from client |
| 355 | msg = socket.recv_json() |
| 356 | rank = msg["rank"] |
| 357 | action = msg["action"] |
| 358 | |
| 359 | if action == "pull": |
| 360 | if len(remaining_entries) == 0: |
| 361 | socket.send_json({"task_id": None}) |
| 362 | running_workers -= 1 |
| 363 | logger.info(f"[ server ] Shutting down worker {rank}, remaining {running_workers} workers") |
| 364 | if running_workers == 0 and num_finished == len(all_entries): |
| 365 | logger.info(f"[ server ] All workers finished") |
| 366 | break |
| 367 | else: |
| 368 | entry = remaining_entries.pop() |
| 369 | time_elapsed = time.perf_counter() - start_time |
| 370 | logger.info(f"[ server ] Sending entry {entry['task_id']} to worker {rank}") |
| 371 | remaining = ( |
| 372 | len(remaining_entries) |
nothing calls this directly
no test coverage detected