Main loop.
(
self,
proc_id: int,
comm_notifier: Any,
comm_buf_name: str,
ret_notifier: Any,
ret_buf_name: str,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
backend_config: BackendConfig,
dist_config: DistConfig,
misc_config: MiscConfig,
specdecode_config: SpecDecodeConfig = None,
adapters: dict[str, str] = None,
device_type: str = 'cuda',
log_level: int = 30,
trust_remote_code: bool = False,
)
| 497 | self._proc.join() |
| 498 | |
| 499 | def _main_loop( |
| 500 | self, |
| 501 | proc_id: int, |
| 502 | comm_notifier: Any, |
| 503 | comm_buf_name: str, |
| 504 | ret_notifier: Any, |
| 505 | ret_buf_name: str, |
| 506 | model_path: str, |
| 507 | model_config: ModelConfig, |
| 508 | cache_config: CacheConfig, |
| 509 | backend_config: BackendConfig, |
| 510 | dist_config: DistConfig, |
| 511 | misc_config: MiscConfig, |
| 512 | specdecode_config: SpecDecodeConfig = None, |
| 513 | adapters: dict[str, str] = None, |
| 514 | device_type: str = 'cuda', |
| 515 | log_level: int = 30, |
| 516 | trust_remote_code: bool = False, |
| 517 | ): |
| 518 | """Main loop.""" |
| 519 | init_backend(device_type) |
| 520 | torch.cuda.set_device(proc_id) |
| 521 | |
| 522 | # catch signal |
| 523 | def handle_sigterm(signum, frame): |
| 524 | logger.debug(f'Proc[{proc_id}] terminated.') |
| 525 | exit(0) |
| 526 | |
| 527 | signal.signal(signal.SIGTERM, handle_sigterm) |
| 528 | |
| 529 | worker = MPWorkerWrapper(model_path, |
| 530 | cache_config=cache_config, |
| 531 | backend_config=backend_config, |
| 532 | model_config=model_config, |
| 533 | dist_config=dist_config, |
| 534 | misc_config=misc_config, |
| 535 | specdecode_config=specdecode_config, |
| 536 | adapters=adapters, |
| 537 | device_type=device_type, |
| 538 | log_level=log_level, |
| 539 | trust_remote_code=trust_remote_code) |
| 540 | try_import_deeplink(device_type) |
| 541 | worker.init_process_group(proc_id) |
| 542 | comm_buf = SharedBuffer(proc_id, notifier=comm_notifier, name=comm_buf_name) |
| 543 | ret_buf = SharedBuffer(-1, notifier=ret_notifier, name=ret_buf_name) |
| 544 | event_loop = asyncio.new_event_loop() |
| 545 | asyncio.set_event_loop(event_loop) |
| 546 | destroy_pg = worker.world_size > 1 |
| 547 | try: |
| 548 | event_loop.run_until_complete( |
| 549 | self._main_loop_impl(proc_id, comm_buf=comm_buf, ret_buf=ret_buf, worker=worker)) |
| 550 | except asyncio.CancelledError: |
| 551 | logger.warning(f'Proc[{proc_id}] main loop cancelled.') |
| 552 | destroy_pg = False |
| 553 | os.kill(os.getppid(), signal.SIGUSR1) |
| 554 | except SystemExit: |
| 555 | # terminated by executor |
| 556 | logger.debug(f'Proc[{proc_id}] system exit.') |
nothing calls this directly
no test coverage detected