根据配置文件构建模型,可选加载checkpoint权重。
build_transformer_model(
config_path=None,
checkpoint_path=None,
model='bert',
application='encoder',
return_keras_model=True,
**kwargs
)
| 2638 | |
| 2639 | |
| 2640 | def build_transformer_model( |
| 2641 | config_path=None, |
| 2642 | checkpoint_path=None, |
| 2643 | model='bert', |
| 2644 | application='encoder', |
| 2645 | return_keras_model=True, |
| 2646 | **kwargs |
| 2647 | ): |
| 2648 | """根据配置文件构建模型,可选加载checkpoint权重 |
| 2649 | """ |
| 2650 | configs = {} |
| 2651 | if config_path is not None: |
| 2652 | configs.update(json.load(open(config_path))) |
| 2653 | configs.update(kwargs) |
| 2654 | if 'max_position' not in configs: |
| 2655 | configs['max_position'] = configs.get('max_position_embeddings', 512) |
| 2656 | if 'dropout_rate' not in configs: |
| 2657 | configs['dropout_rate'] = configs.get('hidden_dropout_prob') |
| 2658 | if 'attention_dropout_rate' not in configs: |
| 2659 | configs['attention_dropout_rate'] = configs.get( |
| 2660 | 'attention_probs_dropout_prob' |
| 2661 | ) |
| 2662 | if 'segment_vocab_size' not in configs: |
| 2663 | configs['segment_vocab_size'] = configs.get('type_vocab_size', 2) |
| 2664 | |
| 2665 | models = { |
| 2666 | 'bert': BERT, |
| 2667 | 'albert': ALBERT, |
| 2668 | 'albert_unshared': ALBERT_Unshared, |
| 2669 | 'roberta': BERT, |
| 2670 | 'nezha': NEZHA, |
| 2671 | 'roformer': RoFormer, |
| 2672 | 'roformer_v2': RoFormerV2, |
| 2673 | 'electra': ELECTRA, |
| 2674 | 'gpt': GPT, |
| 2675 | 'gpt2': GPT2, |
| 2676 | 'gpt2_ml': GPT2_ML, |
| 2677 | 't5': T5, |
| 2678 | 't5_encoder': T5_Encoder, |
| 2679 | 't5_decoder': T5_Decoder, |
| 2680 | 't5.1.0': T5, |
| 2681 | 't5.1.0_encoder': T5_Encoder, |
| 2682 | 't5.1.0_decoder': T5_Decoder, |
| 2683 | 't5.1.1': T5, |
| 2684 | 't5.1.1_encoder': T5_Encoder, |
| 2685 | 't5.1.1_decoder': T5_Decoder, |
| 2686 | 'mt5.1.1': T5, |
| 2687 | 'mt5.1.1_encoder': T5_Encoder, |
| 2688 | 'mt5.1.1_decoder': T5_Decoder, |
| 2689 | } |
| 2690 | |
| 2691 | if is_string(model): |
| 2692 | model = model.lower() |
| 2693 | MODEL = models[model] |
| 2694 | if model.endswith('t5.1.1'): |
| 2695 | configs['version'] = model |
| 2696 | else: |
| 2697 | MODEL = model |
no test coverage detected