(model, checkpoint_path, args, task_tokens=None)
| 14 | |
| 15 | |
def load_pretrained(model, checkpoint_path, args, task_tokens=None):
    """Load a pretrained checkpoint into ``model`` in place.

    The checkpoint file is resolved from ``checkpoint_path`` via
    ``get_checkpoint_iteration`` / ``get_checkpoint_name`` and loaded onto the
    CPU. Training wrappers are peeled off before loading so that the weights
    land on the underlying network:

    - ``.module`` when ``args.deepspeed``;
    - ``TorchDDP``;
    - ``FP16_Module``;
    - any ``.model`` attribute.

    When ``args.block_lm`` is set, two tables get special handling:
    ``transformer.position_embeddings.weight`` and
    ``transformer.block_position_embeddings.weight``. If a checkpoint table is
    shorter than ``args.max_position_embeddings + 1``, it is copied into a
    clone of the model's current table of the same name. The leading rows come
    from the checkpoint and the remaining rows keep the model's current values.

    Weights are loaded with ``strict=False``; missing/unexpected keys are
    reported via ``print_rank_0`` rather than raised.

    Args:
        model: the (possibly wrapped) model to populate.
        checkpoint_path: path handed to ``get_checkpoint_iteration``.
        args: config namespace. Reads ``deepspeed``, ``block_lm``,
            ``max_position_embeddings``, ``continuous_prompt`` and
            ``prompt_init``.
        task_tokens: forwarded to ``model.prompt_spell.init_embedding`` when
            ``args.continuous_prompt`` and ``args.prompt_init`` are both set.

    Returns:
        None. ``model`` is modified in place.
    """
    # NOTE(review): the success flag from get_checkpoint_iteration is not
    # consulted; presumably torch.load below fails if no checkpoint exists.
    load_dir, tag, release, _success = get_checkpoint_iteration(checkpoint_path)
    checkpoint_name = get_checkpoint_name(load_dir, tag, release)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading pretrained model {}'.format(
            torch.distributed.get_rank(), checkpoint_name))
    # Load the checkpoint onto the CPU regardless of the device it was saved on.
    sd = torch.load(checkpoint_name, map_location='cpu')
    module_sd = sd['module']

    # Peel off training wrappers so state-dict keys match the bare network.
    if args.deepspeed:
        model = model.module
    if isinstance(model, TorchDDP):
        model = model.module
    if isinstance(model, FP16_Module):
        model = model.module
    if hasattr(model, "model"):
        model = model.model

    if args.block_lm:
        target_length = args.max_position_embeddings + 1
        # Built lazily (state_dict() walks every parameter) and at most once.
        model_sd = None
        for key, label in (
                ("transformer.position_embeddings.weight", "position embedding"),
                ("transformer.block_position_embeddings.weight",
                 "block position embedding")):
            # Guard on the key that is actually read. Otherwise a checkpoint
            # with a short position table but no block-position table skips the
            # extension. load_state_dict then fails on the size mismatch, even
            # with strict=False.
            ckpt_weights = module_sd.get(key)
            if ckpt_weights is None or ckpt_weights.shape[0] >= target_length:
                continue
            if model_sd is None:
                model_sd = model.state_dict()
            if key not in model_sd:
                # Nothing to extend into; strict=False below will report the
                # checkpoint entry as an unexpected key.
                continue
            # Start from the model's current table and overwrite the leading
            # rows with the checkpoint's rows.
            extended = model_sd[key].data.clone()
            extended[:ckpt_weights.shape[0]] = ckpt_weights
            module_sd[key] = extended
            print_rank_0(f"Extend {label} to {target_length}")

    missing_keys, unexpected_keys = model.load_state_dict(module_sd, strict=False)
    if missing_keys or unexpected_keys:
        print_rank_0(f"Missing keys {missing_keys}, unexpected keys {unexpected_keys}")
    if args.continuous_prompt and args.prompt_init:
        # Initialise the continuous-prompt embeddings from the just-loaded word
        # embedding matrix (exact semantics live in prompt_spell.init_embedding).
        model.prompt_spell.init_embedding(model.word_embeddings.weight.data, task_tokens)
| 60 | |
| 61 | |
| 62 | def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_length=None): |
no test coverage detected