Instantiate a PreTrainedBertModel from a pre-trained model file or a PyTorch state dict. Download and cache the pre-trained model file if needed. Params: pretrained_model_name: either: - a str with the name of a pre-trained model to load, selected from the list of predefined model names (e.g. `bert-base-uncased`), or - a path or URL to a pretrained model archive.
(cls, pretrained_model_name, state_dict=None, cache_dir=None,
fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12,
fp32_tokentypes=False, *inputs, **kwargs)
| 716 | |
| 717 | @classmethod |
| 718 | def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, |
| 719 | fp32_layernorm=False, fp32_embedding=False, layernorm_epsilon=1e-12, |
| 720 | fp32_tokentypes=False, *inputs, **kwargs): |
| 721 | """ |
| 722 | Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. |
| 723 | Download and cache the pre-trained model file if needed. |
| 724 | |
| 725 | Params: |
| 726 | pretrained_model_name: either: |
| 727 | - a str with the name of a pre-trained model to load selected in the list of: |
| 728 | . `bert-base-uncased` |
| 729 | . `bert-large-uncased` |
| 730 | . `bert-base-cased` |
| 731 | . `bert-large-cased` |
| 732 | . `bert-base-multilingual-uncased` |
| 733 | . `bert-base-multilingual-cased` |
| 734 | . `bert-base-chinese` |
| 735 | - a path or url to a pretrained model archive containing: |
| 736 | . `bert_config.json` a configuration file for the model |
| 737 | . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance |
| 738 | cache_dir: an optional path to a folder in which the pre-trained models will be cached. |
| 739 | state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models |
| 740 | *inputs, **kwargs: additional input for the specific Bert class |
| 741 | (ex: num_labels for BertForSequenceClassification) |
| 742 | """ |
| 743 | if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: |
| 744 | archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] |
| 745 | else: |
| 746 | archive_file = pretrained_model_name |
| 747 | # redirect to the cache, if necessary |
| 748 | try: |
| 749 | resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) |
| 750 | except FileNotFoundError: |
| 751 | logger.error( |
| 752 | "Model name '{}' was not found in model name list ({}). " |
| 753 | "We assumed '{}' was a path or url but couldn't find any file " |
| 754 | "associated to this path or url.".format( |
| 755 | pretrained_model_name, |
| 756 | ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), |
| 757 | archive_file)) |
| 758 | return None |
| 759 | if resolved_archive_file == archive_file: |
| 760 | logger.info("loading archive file {}".format(archive_file)) |
| 761 | else: |
| 762 | logger.info("loading archive file {} from cache at {}".format( |
| 763 | archive_file, resolved_archive_file)) |
| 764 | tempdir = None |
| 765 | if os.path.isdir(resolved_archive_file): |
| 766 | serialization_dir = resolved_archive_file |
| 767 | else: |
| 768 | # Extract archive to temp dir |
| 769 | tempdir = tempfile.mkdtemp() |
| 770 | logger.info("extracting archive file {} to temp dir {}".format( |
| 771 | resolved_archive_file, tempdir)) |
| 772 | with tarfile.open(resolved_archive_file, 'r:gz') as archive: |
| 773 | def is_within_directory(directory, target): |
| 774 | |
| 775 | abs_directory = os.path.abspath(directory) |
nothing calls this directly
no test coverage detected