Override this method if you wish to initialize your model differently.
`_from_pretrained(cls, model_id, revision, cache_dir, force_download, proxies, resume_download, local_files_only, use_auth_token, map_location="cpu", strict=False, **model_kwargs)`
| 461 | |
| 462 | @classmethod |
| 463 | def _from_pretrained( |
| 464 | cls, |
| 465 | model_id, |
| 466 | revision, |
| 467 | cache_dir, |
| 468 | force_download, |
| 469 | proxies, |
| 470 | resume_download, |
| 471 | local_files_only, |
| 472 | use_auth_token, |
| 473 | map_location="cpu", |
| 474 | strict=False, |
| 475 | **model_kwargs, |
| 476 | ): |
| 477 | """Overwrite this method in case you wish to initialize your model in a different way.""" |
| 478 | map_location = torch.device(map_location) |
| 479 | ltp = cls(**model_kwargs).to(map_location) |
| 480 | |
| 481 | if os.path.isdir(model_id): |
| 482 | print("Loading weights from local directory") |
| 483 | model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME) |
| 484 | tokenizer = AutoTokenizer.from_pretrained( |
| 485 | model_id, config=ltp.model.backbone.config, use_fast=True |
| 486 | ) |
| 487 | else: |
| 488 | model_file = cls.download( |
| 489 | repo_id=model_id, |
| 490 | filename=PYTORCH_WEIGHTS_NAME, |
| 491 | revision=revision, |
| 492 | cache_dir=cache_dir, |
| 493 | force_download=force_download, |
| 494 | proxies=proxies, |
| 495 | resume_download=resume_download, |
| 496 | use_auth_token=use_auth_token, |
| 497 | local_files_only=local_files_only, |
| 498 | ) |
| 499 | tokenizer = AutoTokenizer.from_pretrained( |
| 500 | pretrained_model_name_or_path=model_id, |
| 501 | config=ltp.model.backbone.config, |
| 502 | revision=revision, |
| 503 | cache_dir=cache_dir, |
| 504 | force_download=force_download, |
| 505 | proxies=proxies, |
| 506 | resume_download=resume_download, |
| 507 | use_auth_token=use_auth_token, |
| 508 | local_files_only=local_files_only, |
| 509 | use_fast=True, |
| 510 | ) |
| 511 | |
| 512 | ltp.tokenizer = tokenizer |
| 513 | state_dict = torch.load(model_file, map_location=map_location) |
| 514 | ltp.load_state_dict(state_dict, strict=strict) |
| 515 | ltp.eval() |
| 516 | |
| 517 | return ltp |
| 518 | |
| 519 | |
| 520 | def main(): |
nothing calls this directly
no test coverage detected