MCP · Copy · Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / BertModel

Class BertModel

Megatron-LM/model/model.py:45–89  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

43
44
class BertModel(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper that owns a ``BertForPreTraining``.

    The wrapped network is stored on ``self.model``. It is built either from
    pretrained weights (``args.pretrained_bert`` truthy) or from scratch from
    a ``BertConfig`` assembled out of ``args``. In the from-scratch case the
    config is also kept on ``self.config``. The pretrained path does not set
    ``self.config`` on this wrapper.

    ``state_dict`` and ``load_state_dict`` are forwarded directly to the inner
    model. Checkpoint keys therefore do NOT carry the ``model.`` prefix that
    the default ``nn.Module`` implementation would add for the child module.
    NOTE(review): presumably this keeps checkpoints interchangeable with a
    bare ``BertForPreTraining`` — confirm against the checkpoint loaders.
    """

    def __init__(self, args):
        """Build the wrapped model from a config-like ``args`` object.

        Args:
            args: object exposing the attributes read below, e.g.
                ``pretrained_bert``, ``hidden_size``, ``num_layers``, and so
                on. It is presumably an argparse namespace (TODO confirm
                against the caller).
        """
        super().__init__()

        if not args.pretrained_bert:
            # From-scratch path: the feed-forward width falls back to the
            # conventional 4x hidden size when not given explicitly.
            ffn_width = args.intermediate_size
            if ffn_width is None:
                ffn_width = 4 * args.hidden_size
            self.config = BertConfig(
                args.tokenizer_num_tokens,
                hidden_size=args.hidden_size,
                num_hidden_layers=args.num_layers,
                num_attention_heads=args.num_attention_heads,
                intermediate_size=ffn_width,
                hidden_dropout_prob=args.hidden_dropout,
                attention_probs_dropout_prob=args.attention_dropout,
                max_position_embeddings=args.max_position_embeddings,
                type_vocab_size=args.tokenizer_num_type_tokens,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                fp32_tokentypes=args.fp32_tokentypes,
                layernorm_epsilon=args.layernorm_epsilon,
                deep_init=args.deep_init)
            self.model = BertForPreTraining(self.config)
            return

        # Pretrained path. ``tokenizer_model_type`` is passed as the first
        # argument of ``from_pretrained``; presumably it names the
        # checkpoint to load (confirm against BertForPreTraining).
        self.model = BertForPreTraining.from_pretrained(
            args.tokenizer_model_type,
            cache_dir=args.cache_dir,
            fp32_layernorm=args.fp32_layernorm,
            fp32_embedding=args.fp32_embedding,
            layernorm_epsilon=args.layernorm_epsilon)

    def forward(self, input_tokens, token_type_ids=None,
                attention_mask=None, checkpoint_activations=False):
        """Run the wrapped model and return whatever it returns.

        ``input_tokens``, ``token_type_ids`` and ``attention_mask`` are passed
        positionally, in that order. ``checkpoint_activations`` is passed by
        keyword.
        """
        leading_args = (input_tokens, token_type_ids, attention_mask)
        return self.model(*leading_args,
                          checkpoint_activations=checkpoint_activations)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the inner model's state dict, without a ``model.`` prefix."""
        inner = self.model
        return inner.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the inner model and return its result."""
        inner = self.model
        return inner.load_state_dict(state_dict, strict=strict)
90

Callers 1

get_model (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected