MCPcopy
hub / github.com/hpcaitech/ColossalAI / BertForPretrain

Class BertForPretrain

examples/tutorial/sequence_parallel/model/bert.py:18–126  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

16
17
18class BertForPretrain(nn.Module):
def __init__(
    self,
    vocab_size,
    hidden_size,
    max_sequence_length,
    num_attention_heads,
    num_layers,
    add_binary_head,
    is_naive_fp16,
    num_tokentypes=2,
    dropout_prob=0.1,
    mlp_ratio=4,
    init_std=0.02,
    convert_fp16_to_fp32_in_softmax=False,
):
    """Assemble the sub-modules of the model and initialise its parameters.

    The constructor wires up, in this order: a ``PreProcessor``, an
    ``Embedding``, ``num_layers`` ``BertLayer`` instances, a final
    ``LayerNorm`` and a ``BertDualHead``; it then calls
    ``self.reset_parameters()``.

    Args:
        vocab_size: Forwarded to ``Embedding``.
        hidden_size: Forwarded to ``Embedding``, every ``BertLayer``,
            ``LayerNorm`` and ``BertDualHead``.
        max_sequence_length: Must be a multiple of the sequence-parallel
            world size; also forwarded to ``Embedding``.
        num_attention_heads: Forwarded to every ``BertLayer``.
        num_layers: Number of ``BertLayer`` instances to create.
        add_binary_head: Forwarded to ``BertDualHead``. When falsy,
            ``num_tokentypes`` is overridden with ``0``.
        is_naive_fp16: Forwarded to every ``BertLayer``.
        num_tokentypes: Forwarded to ``Embedding`` (see ``add_binary_head``).
        dropout_prob: Used as the embedding, attention and hidden dropout
            probability.
        mlp_ratio: Forwarded to every ``BertLayer``.
        init_std: Stored on ``self.init_std``.
        convert_fp16_to_fp32_in_softmax: Forwarded to every ``BertLayer``.

    Raises:
        AssertionError: If ``max_sequence_length`` is not divisible by the
            sequence-parallel world size.
    """
    super().__init__()

    # Each rank handles an equal slice of the sequence, so the full length
    # has to split evenly across the sequence-parallel group.
    world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
    self.seq_parallel_size = world_size
    assert (
        max_sequence_length % world_size == 0
    ), "sequence length is not divisible by the sequence parallel size"
    self.sub_seq_length = max_sequence_length // world_size
    self.init_std = init_std
    self.num_layers = num_layers

    # Without the binary head no token-type embedding is requested.
    num_tokentypes = num_tokentypes if add_binary_head else 0

    self.preprocessor = PreProcessor(self.sub_seq_length)
    self.embedding = Embedding(
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        max_sequence_length=max_sequence_length,
        embedding_dropout_prob=dropout_prob,
        num_tokentypes=num_tokentypes,
    )

    # Settings common to every layer; only the 1-based layer number varies.
    shared_layer_kwargs = dict(
        hidden_size=hidden_size,
        num_attention_heads=num_attention_heads,
        attention_dropout=dropout_prob,
        mlp_ratio=mlp_ratio,
        hidden_dropout=dropout_prob,
        convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
        is_naive_fp16=is_naive_fp16,
    )
    self.bert_layers = nn.ModuleList(
        [
            BertLayer(layer_number=layer_number, **shared_layer_kwargs)
            for layer_number in range(1, num_layers + 1)
        ]
    )

    self.layer_norm = LayerNorm(hidden_size)

    # The head's second argument is the first dimension of the word
    # embedding weight.
    embedding_rows = self.embedding.word_embedding_weight.size(0)
    self.head = BertDualHead(hidden_size, embedding_rows, add_binary_head=add_binary_head)
    self.reset_parameters()
74
75 def _init_normal(self, tensor):

Callers 1

main (Function) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…