GPT-2 language model. The output of the forward method is the logits (parallel or serial, depending on the `parallel_output` flag).
| 33 | |
| 34 | |
| 35 | class GPT2Model(torch.nn.Module): |
| 36 | """GPT-2 Language model. |
| 37 | |
| 38 | The output of the forward method are the logits (parallel or |
| 39 | serial depending on the `parallel_output` flag. |
| 40 | """ |
| 41 | |
def __init__(self,
             num_layers,
             vocab_size,
             hidden_size,
             num_attention_heads,
             embedding_dropout_prob,
             attention_dropout_prob,
             output_dropout_prob,
             max_sequence_length,
             checkpoint_activations,
             checkpoint_num_layers=1,
             parallel_output=True):
    """Create the embedding, dropout and transformer sub-modules.

    Args:
        num_layers: forwarded to ``mpu.GPT2ParallelTransformer``.
        vocab_size: number of rows of the word-embedding table.
        hidden_size: embedding width; also forwarded to the transformer.
        num_attention_heads: forwarded to the transformer.
        embedding_dropout_prob: probability for the dropout layer stored
            as ``self.embedding_dropout``.
        attention_dropout_prob: forwarded to the transformer.
        output_dropout_prob: forwarded to the transformer.
        max_sequence_length: number of rows of the position-embedding table.
        checkpoint_activations: forwarded to the transformer.
        checkpoint_num_layers: forwarded to the transformer (default 1).
        parallel_output: stored as ``self.parallel_output``; per the class
            docstring it selects parallel vs. serial logits.
    """
    super(GPT2Model, self).__init__()

    self.parallel_output = parallel_output

    # Shared initializer for both embedding tables (presumably a normal
    # distribution with std 0.02 -- see init_method_normal).
    normal_init = init_method_normal(std=0.02)

    # NOTE: keep the construction order below as-is. Submodule
    # constructors can consume the RNG (torch.nn.Embedding randomly
    # initializes its weight), so reordering would change the initial
    # weights. It would also change the parameter/state_dict order.

    # Word embeddings, model-parallel over the vocabulary (presumably split
    # across model-parallel ranks -- see mpu.VocabParallelEmbedding).
    self.word_embeddings = mpu.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=normal_init)

    # Position embeddings are an ordinary (non-parallel) table. Its default
    # initialization is immediately overwritten by the shared initializer.
    self.position_embeddings = torch.nn.Embedding(
        num_embeddings=max_sequence_length, embedding_dim=hidden_size)
    normal_init(self.position_embeddings.weight)

    # Dropout applied to the summed word + position embeddings.
    self.embedding_dropout = torch.nn.Dropout(p=embedding_dropout_prob)

    # Transformer stack; arguments are passed positionally, in the exact
    # order the constructor expects.
    transformer_args = (
        num_layers,
        hidden_size,
        num_attention_heads,
        attention_dropout_prob,
        output_dropout_prob,
        checkpoint_activations,
        checkpoint_num_layers,
    )
    self.transformer = mpu.GPT2ParallelTransformer(*transformer_args)
| 82 | |
| 83 | def forward(self, input_ids, position_ids, attention_mask): |
| 84 | |
| 85 | # Embeddings. |
| 86 | words_embeddings = self.word_embeddings(input_ids) |
| 87 | position_embeddings = self.position_embeddings(position_ids) |
| 88 | embeddings = words_embeddings + position_embeddings |
| 89 | |
| 90 | # Dropout. |
| 91 | embeddings = self.embedding_dropout(embeddings) |
| 92 |