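Download the model from the hub, build all of its components, and load the pretrained weights. This method handles the full model-construction pipeline: 1. Download the model files from ModelScope/HuggingFace (if they are not available locally). 2. Parse config.yaml to determine the model class, tokenizer, and frontend. 3. Instantiate the tokenizer, frontend, and model via the registry. 4. Load the pretrained weights from model.pt.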
(**kwargs)
| 279 | |
| 280 | @staticmethod |
| 281 | def build_model(**kwargs): |
| 282 | """Download model from hub, build all components, and load pretrained weights. |
| 283 | |
| 284 | This method handles the full model construction pipeline: |
| 285 | 1. Download model files from ModelScope/HuggingFace (if not local) |
| 286 | 2. Parse config.yaml to determine model class, tokenizer, frontend |
| 287 | 3. Instantiate tokenizer, frontend, and model via the registry |
| 288 | 4. Load pretrained weights from model.pt |
| 289 | |
| 290 | Args: |
| 291 | **kwargs: Must include 'model' (str). All other config.yaml fields can be overridden. |
| 292 | |
| 293 | Returns: |
| 294 | tuple: (model, kwargs) where model is the instantiated nn.Module and |
| 295 | kwargs contains the resolved configuration. |
| 296 | """ |
| 297 | assert "model" in kwargs |
| 298 | if "model_conf" not in kwargs: |
| 299 | logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms"))) |
| 300 | kwargs = download_model(**kwargs) |
| 301 | |
| 302 | set_all_random_seed(kwargs.get("seed", 0)) |
| 303 | |
| 304 | device = kwargs.get("device", "cuda") |
| 305 | if ( |
| 306 | (device.startswith("cuda") and not torch.cuda.is_available()) |
| 307 | or (device.startswith("xpu") and not torch.xpu.is_available()) |
| 308 | or (device.startswith("mps") and not torch.backends.mps.is_available()) |
| 309 | or (device.startswith("npu") and not is_npu_available()) |
| 310 | or kwargs.get("ngpu", 1) == 0 |
| 311 | ): |
| 312 | device = "cpu" |
| 313 | kwargs["batch_size"] = 1 |
| 314 | kwargs["device"] = device |
| 315 | |
| 316 | ncpu = _resolve_ncpu(kwargs, 4) |
| 317 | kwargs["ncpu"] = ncpu |
| 318 | if torch.get_num_threads() != ncpu: |
| 319 | torch.set_num_threads(ncpu) |
| 320 | |
| 321 | # build tokenizer |
| 322 | tokenizer = kwargs.get("tokenizer", None) |
| 323 | kwargs["tokenizer"] = tokenizer |
| 324 | kwargs["vocab_size"] = -1 |
| 325 | |
| 326 | if tokenizer is not None: |
| 327 | tokenizers = ( |
| 328 | tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer |
| 329 | ) # type of tokenizers is list!!! |
| 330 | tokenizers_conf = kwargs.get("tokenizer_conf", {}) |
| 331 | tokenizers_build = [] |
| 332 | vocab_sizes = [] |
| 333 | token_lists = [] |
| 334 | |
| 335 | ### === only for kws === |
| 336 | token_list_files = kwargs.get("token_lists", []) |
| 337 | seg_dicts = kwargs.get("seg_dicts", []) |
| 338 | ### === only for kws === |