(model, bkey)
| 23 | |
| 24 | ################## |
def go(model, bkey):
    """Load ``checkpoint_dict[bkey]`` into ``model``, keeping model values for unusable keys.

    Each key of the model's own state dict takes the checkpoint's value only when
    the checkpoint has that key with an identical ``.shape``. Otherwise the
    model's current value is kept: missing keys are logged at INFO, and shape
    mismatches are logged at WARNING.

    Args:
        model: Object exposing ``state_dict()`` / ``load_state_dict()``. If it has
            a ``.module`` attribute (presumably a DataParallel/DDP wrapper --
            confirm), the wrapped ``model.module`` is the one that gets loaded.
        bkey: Key selecting the sub-state-dict inside ``checkpoint_dict``, a name
            from the enclosing scope that is not visible here.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        Whatever ``checkpoint_dict[bkey]`` raises if ``bkey`` is absent
        (``KeyError`` for a plain dict).
    """
    saved_state_dict = checkpoint_dict[bkey]
    # Unwrap once instead of repeating the hasattr branch for read and write.
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()  # keys/shapes the model needs
    new_state_dict = {}
    for k, v in state_dict.items():
        if k not in saved_state_dict:
            # Missing from the pretrained checkpoint: keep the model's own
            # (freshly initialized) value.
            logger.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v
            continue
        saved = saved_state_dict[k]
        # getattr keeps the old best-effort behavior for entries without a
        # .shape: they fall back to the model's value instead of crashing.
        need_shape = getattr(v, "shape", None)
        got_shape = getattr(saved, "shape", None)
        if need_shape is None or got_shape is None or need_shape != got_shape:
            logger.warning(
                "shape-%s-mismatch. need: %s, get: %s",
                k,
                need_shape,
                got_shape,
            )
            new_state_dict[k] = v
            continue
        new_state_dict[k] = saved
    # new_state_dict covers exactly the keys the model reported; strict=False
    # is kept as in the original call.
    target.load_state_dict(new_state_dict, strict=False)
    return model
| 52 | |
# Load checkpoint weights into both networks. go() returns the model it was
# given; only the "sbd" result is bound to `model`.
go(model=combd, bkey="combd")
model = go(model=sbd, bkey="sbd")
no test coverage detected