Initialize a pretrained GPT model by copying over the weights from a huggingface/transformers checkpoint.
(cls, model_type)
| 173 | |
@classmethod
def from_pretrained(cls, model_type):
    """
    Initialize a pretrained GPT model by copying over the weights
    from a huggingface/transformers checkpoint.

    model_type must be one of 'gpt2', 'gpt2-medium', 'gpt2-large' or 'gpt2-xl'.
    Returns an instance of cls whose state dict entries have been overwritten
    in place with the checkpoint tensors.

    Raises AssertionError if model_type is not supported, if the checkpoint and
    the freshly built model do not have the same number of state dict entries,
    or if any pair of tensors disagrees in shape.
    """
    supported = {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
    assert model_type in supported, \
        "unsupported model_type %r, expected one of %s" % (model_type, sorted(supported))
    # imported here so transformers is only needed when this constructor is actually used
    from transformers import GPT2LMHeadModel

    # create a from-scratch initialized minGPT model
    config = cls.get_default_config()
    config.model_type = model_type
    config.vocab_size = 50257 # openai's model vocabulary
    config.block_size = 1024  # openai's model block_size
    # build through cls rather than a hard-coded class name, so that calling
    # this constructor on a subclass returns an instance of that subclass
    model = cls(config)
    sd = model.state_dict()

    # init a huggingface/transformers model
    model_hf = GPT2LMHeadModel.from_pretrained(model_type)
    sd_hf = model_hf.state_dict()

    # copy while ensuring all of the parameters are aligned and match in names and shapes
    keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these
    # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
    # this means that we have to transpose these weights when we import them
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    assert len(keys) == len(sd), \
        "state dict size mismatch: checkpoint has %d entries, model has %d" % (len(keys), len(sd))

    # a single no_grad context covers every in-place copy below
    with torch.no_grad():
        for k in keys:
            src = sd_hf[k]
            dst = sd[k]
            if k.endswith(transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert src.shape[::-1] == dst.shape, \
                    "shape mismatch for transposed %s: %s vs %s" % (k, tuple(src.shape), tuple(dst.shape))
                src = src.t()
            else:
                # vanilla copy over the other parameters
                assert src.shape == dst.shape, \
                    "shape mismatch for %s: %s vs %s" % (k, tuple(src.shape), tuple(dst.shape))
            dst.copy_(src)

    return model
| 214 | |
| 215 | def configure_optimizers(self, train_config): |
| 216 | """ |