MCPcopy Index your code
hub / github.com/THUDM/GLM / load_pretrained

Function load_pretrained

train_utils.py:16–59  ·  view source on GitHub ↗
(model, checkpoint_path, args, task_tokens=None)

Source from the content-addressed store, hash-verified

14
15
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    """Load pretrained weights from a checkpoint into ``model`` (non-strict).

    The checkpoint resolved from ``checkpoint_path`` is loaded with all tensors
    mapped to CPU. Its ``'module'`` state dict is then loaded into the
    unwrapped model with ``strict=False``. Missing and unexpected keys are
    printed, not raised.

    When ``args.block_lm`` is set, the checkpoint's position and
    block-position embedding tables are padded to
    ``args.max_position_embeddings + 1`` rows if they are shorter. The leading
    rows come from the checkpoint; the remaining rows keep the model's current
    values.

    Args:
        model: Model to load into. Wrapper layers are peeled off before
            loading (see the comment in the body).
        checkpoint_path: Passed to ``get_checkpoint_iteration`` to resolve the
            checkpoint directory, tag and release flag.
        args: Configuration object. This function reads ``deepspeed``,
            ``block_lm``, ``max_position_embeddings``, ``continuous_prompt``
            and ``prompt_init``.
        task_tokens: Forwarded to ``prompt_spell.init_embedding`` when both
            ``args.continuous_prompt`` and ``args.prompt_init`` are set.

    Side effects:
        Mutates the loaded checkpoint dict in place (extended embedding
        tables) and the model's parameters via ``load_state_dict``.
    """
    # NOTE(review): the fourth return value ("success") is ignored here;
    # confirm whether a failed checkpoint lookup should abort instead.
    load_dir, tag, release, _success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint with every tensor mapped to CPU.
    sd = torch.load(checkpoint_name, map_location='cpu')

    # Peel off the wrapper layers so that model.state_dict() exposes the bare
    # "transformer.*" keys looked up below. With args.deepspeed the model is
    # presumably an engine exposing the real module as `.module`.
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    checkpoint_sd = sd['module']

    def extend_embedding_weights(state_weights, model_weights):
        """Return a copy of ``model_weights`` whose leading rows are ``state_weights``.

        Rows past ``state_weights.shape[0]`` keep the model's current values.
        Callers only invoke this when the checkpoint table is shorter than the
        target length.
        """
        new_weights = model_weights.clone()
        new_weights[:state_weights.shape[0]] = state_weights
        return new_weights

    if args.block_lm:
        target_length = args.max_position_embeddings + 1
        # Built lazily and at most once: only needed if a table is extended.
        model_sd = None
        embedding_tables = (
            ("transformer.position_embeddings.weight", "position embedding"),
            ("transformer.block_position_embeddings.weight", "block position embedding"),
        )
        for key, label in embedding_tables:
            # Guard each table on the key that is actually read and written,
            # so a checkpoint lacking one table neither raises nor skips the other.
            if key not in checkpoint_sd:
                continue
            if target_length > checkpoint_sd[key].shape[0]:
                if model_sd is None:
                    model_sd = model.state_dict()
                checkpoint_sd[key] = extend_embedding_weights(
                    checkpoint_sd[key], model_sd[key].data)
                print_rank_0(f"Extend {label} to {target_length}")

    missing_keys, unexpected_keys = model.load_state_dict(checkpoint_sd, strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data, task_tokens)
60
61
62def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_length=None):

Callers 1

finetune (Function) · 0.90

Calls 8

get_checkpoint_iteration (Function) · 0.90
get_checkpoint_name (Function) · 0.90
print_rank_0 (Function) · 0.90
extend_embedding_weights (Function) · 0.85
load (Method) · 0.80
init_embedding (Method) · 0.80
state_dict (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

no test coverage detected