()
| 155 | |
| 156 | |
| 157 | def create_multi_model_worker(): |
| 158 | # Note: Ensure we resolve arg conflicts. We let `add_model_args` add MOST |
| 159 | # of the model args but we'll override one to have an append action that |
| 160 | # supports multiple values. |
| 161 | parser = argparse.ArgumentParser(conflict_handler="resolve") |
| 162 | parser.add_argument("--host", type=str, default="localhost") |
| 163 | parser.add_argument("--port", type=int, default=21002) |
| 164 | parser.add_argument("--worker-address", type=str, default="http://localhost:21002") |
| 165 | parser.add_argument( |
| 166 | "--controller-address", type=str, default="http://localhost:21001" |
| 167 | ) |
| 168 | add_model_args(parser) |
| 169 | # Override the model path to be repeated and align it with model names. |
| 170 | parser.add_argument( |
| 171 | "--model-path", |
| 172 | type=str, |
| 173 | default=[], |
| 174 | action="append", |
| 175 | help="One or more paths to model weights to load. This can be a local folder or a Hugging Face repo ID.", |
| 176 | ) |
| 177 | parser.add_argument( |
| 178 | "--model-names", |
| 179 | type=lambda s: s.split(","), |
| 180 | action="append", |
| 181 | help="One or more model names. Values must be aligned with `--model-path` values.", |
| 182 | ) |
| 183 | parser.add_argument( |
| 184 | "--conv-template", |
| 185 | type=str, |
| 186 | default=None, |
| 187 | action="append", |
| 188 | help="Conversation prompt template. Values must be aligned with `--model-path` values. If only one value is provided, it will be repeated for all models.", |
| 189 | ) |
| 190 | parser.add_argument("--limit-worker-concurrency", type=int, default=5) |
| 191 | parser.add_argument("--stream-interval", type=int, default=2) |
| 192 | parser.add_argument("--no-register", action="store_true") |
| 193 | parser.add_argument( |
| 194 | "--ssl", |
| 195 | action="store_true", |
| 196 | required=False, |
| 197 | default=False, |
| 198 | help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.", |
| 199 | ) |
| 200 | args = parser.parse_args() |
| 201 | logger.info(f"args: {args}") |
| 202 | |
| 203 | if args.gpus: |
| 204 | if len(args.gpus.split(",")) < args.num_gpus: |
| 205 | raise ValueError( |
| 206 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" |
| 207 | ) |
| 208 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus |
| 209 | |
| 210 | gptq_config = GptqConfig( |
| 211 | ckpt=args.gptq_ckpt or args.model_path, |
| 212 | wbits=args.gptq_wbits, |
| 213 | groupsize=args.gptq_groupsize, |
| 214 | act_order=args.gptq_act_order, |
no test coverage detected
searching dependent graphs…