Function build_transformer_model

bert4keras/models.py:2640–2720 · github.com/bojone/bert4keras

Build a model from a config file, optionally loading weights from a checkpoint.

(
    config_path=None,
    checkpoint_path=None,
    model='bert',
    application='encoder',
    return_keras_model=True,
    **kwargs
)
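The config handling at the top of the function works in three steps. First it reads the JSON file, then it lets **kwargs override those values, and finally it back-fills bert4keras' internal key names (max_position, dropout_rate, attention_dropout_rate, segment_vocab_size) from the standard BERT config names. The sketch below reproduces only that step as a standalone function. `merge_configs` is a hypothetical name, not part of bert4keras, and the config values are made up for illustration.

```python
import json
import os
import tempfile


def merge_configs(config_path=None, **kwargs):
    """Hypothetical helper (not part of bert4keras) that mirrors the config
    handling in build_transformer_model."""
    configs = {}
    if config_path is not None:
        # Values from the JSON file come first...
        with open(config_path) as f:
            configs.update(json.load(f))
    # ...and keyword arguments override them.
    configs.update(kwargs)
    # setdefault only writes a key that is still missing, so an explicitly
    # passed bert4keras-style name always wins over the BERT-style alias.
    configs.setdefault(
        'max_position', configs.get('max_position_embeddings', 512))
    configs.setdefault('dropout_rate', configs.get('hidden_dropout_prob'))
    configs.setdefault(
        'attention_dropout_rate', configs.get('attention_probs_dropout_prob'))
    configs.setdefault(
        'segment_vocab_size', configs.get('type_vocab_size', 2))
    return configs


# Usage: write a small BERT-style config, then override one value via kwargs.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'bert_config.json')  # illustrative file name
    with open(path, 'w') as f:
        json.dump({'hidden_dropout_prob': 0.1,
                   'max_position_embeddings': 512,
                   'type_vocab_size': 2}, f)
    cfg = merge_configs(path, hidden_dropout_prob=0.0)

print(cfg['dropout_rate'], cfg['attention_dropout_rate'])  # 0.0 None
```

The kwargs are merged before the aliases are filled in. As a result, overriding `hidden_dropout_prob` also changes `dropout_rate`, and a key that appears in neither source (here `attention_probs_dropout_prob`) ends up as None rather than a numeric default.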

def build_transformer_model(
    config_path=None,
    checkpoint_path=None,
    model='bert',
    application='encoder',
    return_keras_model=True,
    **kwargs
):
    """Build a model from a config file, optionally loading checkpoint weights.
    """
    # Read the JSON config (if given), then let keyword arguments override it.
    configs = {}
    if config_path is not None:
        configs.update(json.load(open(config_path)))
    configs.update(kwargs)
    # Fill bert4keras' internal key names from the standard BERT config names.
    if 'max_position' not in configs:
        configs['max_position'] = configs.get('max_position_embeddings', 512)
    if 'dropout_rate' not in configs:
        configs['dropout_rate'] = configs.get('hidden_dropout_prob')
    if 'attention_dropout_rate' not in configs:
        configs['attention_dropout_rate'] = configs.get(
            'attention_probs_dropout_prob'
        )
    if 'segment_vocab_size' not in configs:
        configs['segment_vocab_size'] = configs.get('type_vocab_size', 2)

    # Map model names to model classes ('roberta' reuses the BERT class).
    models = {
        'bert': BERT,
        'albert': ALBERT,
        'albert_unshared': ALBERT_Unshared,
        'roberta': BERT,
        'nezha': NEZHA,
        'roformer': RoFormer,
        'roformer_v2': RoFormerV2,
        'electra': ELECTRA,
        'gpt': GPT,
        'gpt2': GPT2,
        'gpt2_ml': GPT2_ML,
        't5': T5,
        't5_encoder': T5_Encoder,
        't5_decoder': T5_Decoder,
        't5.1.0': T5,
        't5.1.0_encoder': T5_Encoder,
        't5.1.0_decoder': T5_Decoder,
        't5.1.1': T5,
        't5.1.1_encoder': T5_Encoder,
        't5.1.1_decoder': T5_Decoder,
        'mt5.1.1': T5,
        'mt5.1.1_encoder': T5_Encoder,
        'mt5.1.1_decoder': T5_Decoder,
    }

    if is_string(model):
        model = model.lower()
        MODEL = models[model]
        # Tag T5 v1.1 / mT5 v1.1 names, including the _encoder/_decoder
        # variants, so the T5 classes can switch to the v1.1 variant.
        # (endswith would miss names such as 't5.1.1_encoder'.)
        if model.startswith('t5.1.1'):
            configs['version'] = 't5.1.1'
        elif model.startswith('mt5.1.1'):
            configs['version'] = 'mt5.1.1'
    else:
        # A model class was passed in directly.
        MODEL = model

    # ... (remainder of the function, through line 2720, not shown)
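The name-to-class dispatch at the end of the listing can be sketched on its own. This is a simplified model, not bert4keras code: the classes are empty stand-ins, the registry is trimmed to a few entries, `resolve_model` is a hypothetical name, and `isinstance(model, str)` stands in for bert4keras' `is_string` helper.

```python
# Hypothetical stand-ins for the real bert4keras model classes.
class BERT:
    pass


class T5:
    pass


class T5_Encoder:
    pass


# A trimmed-down registry; the real table has many more entries.
MODELS = {
    'bert': BERT,
    'roberta': BERT,  # RoBERTa checkpoints reuse the BERT architecture
    't5': T5,
    't5.1.1': T5,
    't5.1.1_encoder': T5_Encoder,
    'mt5.1.1': T5,
    'mt5.1.1_encoder': T5_Encoder,
}


def resolve_model(model, configs):
    """Hypothetical helper: map a model name to a class and tag T5 versions."""
    if isinstance(model, str):
        name = model.lower()  # names are case-insensitive
        if name not in MODELS:
            raise ValueError('unknown model %r; expected one of %s'
                             % (model, sorted(MODELS)))
        # startswith (not endswith) so '_encoder'/'_decoder' variants are
        # tagged as well; 'mt5.1.1...' does not start with 't5.1.1'.
        if name.startswith('t5.1.1'):
            configs['version'] = 't5.1.1'
        elif name.startswith('mt5.1.1'):
            configs['version'] = 'mt5.1.1'
        return MODELS[name]
    # Anything that is not a string is assumed to already be a model class.
    return model


cfg = {}
print(resolve_model('MT5.1.1_Encoder', cfg).__name__, cfg)
# T5_Encoder {'version': 'mt5.1.1'}
```

The sketch raises a ValueError that lists the valid names when it receives an unknown model, whereas a plain dict lookup raises a bare KeyError. That difference is a design choice of the sketch. With an `endswith('t5.1.1')` check instead, a name like 't5.1.1_encoder' would still resolve to T5_Encoder but would receive no version tag.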

Calls (6; 3 shown)

is_string (function) · 0.90
open (built-in function) · 0.85
build (method) · 0.45

Tested by

no test coverage detected