hub / github.com/deepspeedai/DeepSpeedExamples / GPT2Model

Class GPT2Model

Megatron-LM/model/gpt2_modeling.py:35–105

GPT-2 language model. The forward method outputs the logits, either parallel or serial depending on the `parallel_output` flag.
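
To make the `parallel_output` distinction concrete, here is a minimal single-process sketch. It assumes, as in Megatron-style vocabulary parallelism, that each model-parallel rank holds a contiguous slice of the vocabulary rows and produces logits only for that slice, and that the serial result is those slices concatenated along the vocabulary dimension. The helper `simulate_parallel_logits` and the `world_size` argument are hypothetical; ranks are simulated with `torch.chunk`, not with real `mpu` calls.

```python
import torch
import torch.nn.functional as F


def simulate_parallel_logits(hidden, embedding_weight, world_size):
    """Return one logits tensor per simulated rank (hypothetical helper)."""
    # Each simulated rank owns a contiguous slice of vocabulary rows.
    shards = torch.chunk(embedding_weight, world_size, dim=0)
    # Per-rank logits have shape [batch, seq, vocab_size / world_size].
    return [F.linear(hidden, shard) for shard in shards]


torch.manual_seed(0)
batch, seq, hidden_size, vocab_size, world_size = 2, 4, 8, 12, 3
hidden = torch.randn(batch, seq, hidden_size)
embedding_weight = torch.randn(vocab_size, hidden_size)

# "Parallel" output: every rank keeps only its own vocabulary slice.
parallel_logits = simulate_parallel_logits(hidden, embedding_weight, world_size)
# "Serial" output: the slices gathered along the vocabulary dimension.
serial_logits = torch.cat(parallel_logits, dim=-1)
# Reference: logits computed against the unsplit weight matrix.
full_logits = F.linear(hidden, embedding_weight)
```

Under these assumptions the gathered tensor matches the unsplit computation, so the flag would change only where the vocabulary dimension lives, not the values.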

Source from the content-addressed store, hash-verified

class GPT2Model(torch.nn.Module):
    """GPT-2 Language model.

    The output of the forward method are the logits (parallel or
    serial depending on the `parallel_output` flag).
    """

    def __init__(self,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 max_sequence_length,
                 checkpoint_activations,
                 checkpoint_num_layers=1,
                 parallel_output=True):

        super(GPT2Model, self).__init__()

        self.parallel_output = parallel_output

        init_method = init_method_normal(std=0.02)

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, hidden_size, init_method=init_method)

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(max_sequence_length,
                                                      hidden_size)
        # Initialize the position embeddings.
        init_method(self.position_embeddings.weight)

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        # Transformer
        self.transformer = mpu.GPT2ParallelTransformer(num_layers,
                                                       hidden_size,
                                                       num_attention_heads,
                                                       attention_dropout_prob,
                                                       output_dropout_prob,
                                                       checkpoint_activations,
                                                       checkpoint_num_layers)

    def forward(self, input_ids, position_ids, attention_mask):

        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

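The constructor forwards `checkpoint_activations` and `checkpoint_num_layers` to the transformer, whose implementation is not shown here. As a rough illustration of what such flags usually control, the sketch below checkpoints a stack of layers in chunks using PyTorch's own `torch.utils.checkpoint.checkpoint`. The helper `run_layers` is hypothetical, the `Linear` layers are placeholders for transformer layers, and the chunking scheme is an assumption about the intent of `checkpoint_num_layers`, not a copy of `mpu.GPT2ParallelTransformer`.

```python
import torch
from torch.utils.checkpoint import checkpoint


def run_layers(layers, hidden, checkpoint_activations, checkpoint_num_layers=1):
    """Apply layers in order, optionally checkpointing them in chunks."""
    if not checkpoint_activations:
        for layer in layers:
            hidden = layer(hidden)
        return hidden

    def make_chunk(start, end):
        # Bundle layers[start:end] into one function for checkpoint().
        def chunk_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return chunk_forward

    # Only each chunk's input is kept; activations inside a chunk are
    # recomputed during the backward pass.
    for start in range(0, len(layers), checkpoint_num_layers):
        chunk = make_chunk(start, start + checkpoint_num_layers)
        hidden = checkpoint(chunk, hidden, use_reentrant=False)
    return hidden


torch.manual_seed(0)
layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8, requires_grad=True)
plain = run_layers(layers, x, checkpoint_activations=False)
ckpt = run_layers(layers, x, checkpoint_activations=True,
                  checkpoint_num_layers=2)
```

Larger chunks store fewer intermediate tensors but recompute more work per backward step, which is the trade-off a parameter like `checkpoint_num_layers` would tune.
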
Callers 3

get_model · Function · 0.90
get_model · Function · 0.90
get_model · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected