MCPcopy
hub / github.com/HIT-SCIR/ltp / _from_pretrained

Method _from_pretrained

python/interface/ltp/nerual.py:463–517  ·  view source on GitHub ↗

Override this method if you wish to initialize your model in a different way.

(
        cls,
        model_id,
        revision,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token,
        map_location="cpu",
        strict=False,
        **model_kwargs,
    )

Source from the content-addressed store, hash-verified

461
@classmethod
def _from_pretrained(
    cls,
    model_id,
    revision,
    cache_dir,
    force_download,
    proxies,
    resume_download,
    local_files_only,
    use_auth_token,
    map_location="cpu",
    strict=False,
    **model_kwargs,
):
    """Build an instance of ``cls`` and load its weights and tokenizer.

    ``model_id`` is either a local directory or a Hub repository id. For a
    local directory, the weights file ``PYTORCH_WEIGHTS_NAME`` and the
    tokenizer are both read from that directory. Otherwise the weights file
    is fetched with ``cls.download`` and the tokenizer with
    ``AutoTokenizer.from_pretrained``, both using the same Hub options.

    Args:
        model_id: Local directory path or Hub repository id.
        revision, cache_dir, force_download, proxies, resume_download,
        local_files_only, use_auth_token: Hub options forwarded unchanged to
            ``cls.download`` and ``AutoTokenizer.from_pretrained``; not used
            when ``model_id`` is a local directory.
        map_location: Anything ``torch.device`` accepts. The model is moved
            to this device and the checkpoint tensors are mapped onto it.
        strict: Passed to ``load_state_dict``. When False, key mismatches
            between checkpoint and model are tolerated but logged as a
            warning.
        **model_kwargs: Keyword arguments for the ``cls`` constructor.

    Returns:
        The constructed model, with ``tokenizer`` attached, weights loaded,
        and switched to eval mode.
    """
    import logging

    logger = logging.getLogger(__name__)

    map_location = torch.device(map_location)
    ltp = cls(**model_kwargs).to(map_location)
    # The tokenizer is built against the backbone's config in both branches.
    backbone_config = ltp.model.backbone.config

    if os.path.isdir(model_id):
        # Library code should not write to stdout; use logging instead.
        logger.info("Loading weights from local directory %s", model_id)
        model_file = os.path.join(model_id, PYTORCH_WEIGHTS_NAME)
        tokenizer = AutoTokenizer.from_pretrained(
            model_id, config=backbone_config, use_fast=True
        )
    else:
        # One shared set of Hub options so the weights file and the tokenizer
        # are always resolved with the same revision/cache/auth settings.
        hub_kwargs = {
            "revision": revision,
            "cache_dir": cache_dir,
            "force_download": force_download,
            "proxies": proxies,
            "resume_download": resume_download,
            "use_auth_token": use_auth_token,
            "local_files_only": local_files_only,
        }
        model_file = cls.download(
            repo_id=model_id, filename=PYTORCH_WEIGHTS_NAME, **hub_kwargs
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=model_id,
            config=backbone_config,
            use_fast=True,
            **hub_kwargs,
        )

    ltp.tokenizer = tokenizer
    # SECURITY(review): torch.load unpickles the file, which can execute
    # arbitrary code if the checkpoint comes from an untrusted repository.
    # Consider weights_only=True (torch >= 1.13) once the minimum supported
    # torch version allows it.
    state_dict = torch.load(model_file, map_location=map_location)
    incompatible = ltp.load_state_dict(state_dict, strict=strict)

    # With strict=False, mismatched keys are skipped silently, so a wrong or
    # stale checkpoint would leave parts of the model at their initial
    # values. Surface that as a warning rather than an error. getattr keeps
    # this safe if a subclass's load_state_dict returns something else.
    missing = getattr(incompatible, "missing_keys", ())
    unexpected = getattr(incompatible, "unexpected_keys", ())
    if missing or unexpected:
        logger.warning(
            "Checkpoint %s does not fully match the model: "
            "missing keys %s, unexpected keys %s",
            model_file,
            list(missing),
            list(unexpected),
        )

    ltp.eval()
    return ltp
518
519
520def main():

Callers

nothing calls this directly

Calls 5

device — Method · 0.80
to — Method · 0.80
from_pretrained — Method · 0.80
download — Method · 0.80
load_state_dict — Method · 0.45

Tested by

no test coverage detected