MCPcopy
hub / github.com/karpathy/minGPT / from_pretrained

Method from_pretrained

mingpt/model.py:175–213  ·  view source on GitHub ↗

Initialize a pretrained GPT model by copying over the weights from a Hugging Face `transformers` checkpoint.

(cls, model_type)

Source retrieved from the content-addressed store (hash-verified).

173
@classmethod
def from_pretrained(cls, model_type):
    """
    Initialize a pretrained GPT model by copying over the weights
    from a huggingface/transformers checkpoint.

    Args:
        model_type: name of the OpenAI GPT-2 checkpoint to load; one of
            'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'. It is passed
            straight to ``GPT2LMHeadModel.from_pretrained`` (which may need
            to download the checkpoint).

    Returns:
        An instance of ``cls`` whose learned parameters have been overwritten
        in place with the checkpoint's weights.

    Raises:
        ValueError: if ``model_type`` is not one of the supported checkpoints.
        AssertionError: if the checkpoint's parameter names or shapes do not
            line up with the freshly constructed model.
    """
    supported = {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
    # explicit raise rather than assert: asserts are stripped under `python -O`
    if model_type not in supported:
        raise ValueError(
            f"unsupported model_type {model_type!r}; expected one of {sorted(supported)}"
        )
    from transformers import GPT2LMHeadModel

    # create a from-scratch initialized minGPT model
    config = cls.get_default_config()
    config.model_type = model_type
    config.vocab_size = 50257  # openai's model vocabulary
    config.block_size = 1024   # openai's model block_size
    # build via cls (not a hard-coded GPT) so subclasses get an instance of themselves
    model = cls(config)
    # state_dict() tensors share storage with the model's parameters/buffers,
    # so copy_ into them below updates the model in place
    sd = model.state_dict()

    # init a huggingface/transformers model
    model_hf = GPT2LMHeadModel.from_pretrained(model_type)
    sd_hf = model_hf.state_dict()

    # The '.attn.bias' / '.attn.masked_bias' entries are attention-mask buffers,
    # not learned weights. Depending on the installed transformers version they
    # may be absent from the HF state_dict (e.g. when registered non-persistent),
    # so they are excluded on BOTH sides and only the remaining entries are
    # matched. The leading '.' keeps learned params such as '...attn.c_attn.bias'
    # from matching.
    # NOTE(review): assumes minGPT's own '.attn.bias' buffer is a causal mask fully
    # built from block_size in the constructor — confirm in CausalSelfAttention.
    ignored_suffixes = ('.attn.bias', '.attn.masked_bias')
    keys = [k for k in sd_hf if not k.endswith(ignored_suffixes)]
    own_keys = {k for k in sd if not k.endswith(ignored_suffixes)}

    # copy while ensuring all of the parameters are aligned and match in names and shapes
    missing = own_keys.difference(keys)
    unexpected = set(keys).difference(own_keys)
    assert not missing and not unexpected, (
        f"state_dict mismatch: missing from checkpoint {sorted(missing)}, "
        f"not present in model {sorted(unexpected)}"
    )

    # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
    # this means that we have to transpose these weights when we import them
    transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
    with torch.no_grad():
        for k in keys:
            src = sd_hf[k]
            if k.endswith(transposed):
                # special treatment for the Conv1D weights we need to transpose
                src = src.t()
            # explicit shape check: copy_ would otherwise silently broadcast a mismatched tensor
            assert src.shape == sd[k].shape, (
                f"shape mismatch for {k}: checkpoint {tuple(src.shape)} "
                f"vs model {tuple(sd[k].shape)}"
            )
            sd[k].copy_(src)

    return model
214
215 def configure_optimizers(self, train_config):
216 """

Callers 1

test_gpt2 (method) · 0.80

Calls 2

GPT (class) · 0.85
get_default_config (method) · 0.45

Tested by 1

test_gpt2 (method) · 0.64