(
self,
controller_addr: str,
worker_addr: str,
worker_id: str,
model_path: str,
model_names: List[str],
limit_worker_concurrency: int,
conv_template: str = None,
multimodal: bool = False,
)
| 26 | |
| 27 | class BaseModelWorker: |
| 28 | def __init__( |
| 29 | self, |
| 30 | controller_addr: str, |
| 31 | worker_addr: str, |
| 32 | worker_id: str, |
| 33 | model_path: str, |
| 34 | model_names: List[str], |
| 35 | limit_worker_concurrency: int, |
| 36 | conv_template: str = None, |
| 37 | multimodal: bool = False, |
| 38 | ): |
| 39 | global logger, worker |
| 40 | |
| 41 | self.controller_addr = controller_addr |
| 42 | self.worker_addr = worker_addr |
| 43 | self.worker_id = worker_id |
| 44 | if model_path.endswith("/"): |
| 45 | model_path = model_path[:-1] |
| 46 | self.model_names = model_names or [model_path.split("/")[-1]] |
| 47 | self.limit_worker_concurrency = limit_worker_concurrency |
| 48 | self.conv = self.make_conv_template(conv_template, model_path) |
| 49 | self.conv.sep_style = int(self.conv.sep_style) |
| 50 | self.multimodal = multimodal |
| 51 | self.tokenizer = None |
| 52 | self.context_len = None |
| 53 | self.call_ct = 0 |
| 54 | self.semaphore = None |
| 55 | |
| 56 | self.heart_beat_thread = None |
| 57 | |
| 58 | if logger is None: |
| 59 | logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log") |
| 60 | if worker is None: |
| 61 | worker = self |
| 62 | |
| 63 | def make_conv_template( |
| 64 | self, |
nothing calls this directly
no test coverage detected