MCPcopy
hub / github.com/microsoft/Cream / convert_to_new_checkpoint

Function convert_to_new_checkpoint

TinyCLIP/src/open_clip/model.py:1115–1157  ·  view source on GitHub ↗
(state_dict, used_ddp=False)

Source from the content-addressed store, hash-verified

1113
1114
def convert_to_new_checkpoint(state_dict, used_ddp=False):
    """Normalize a CLIP state dict to the ``_image_encoder`` / ``_text_encoder``
    / ``_logit_scale`` key layout, with or without a ``module`` segment.

    Three input layouts are recognized, tested in this order:

    1. Marker key ``'_logit_scale.module.logit_scale'`` is present.
       With ``used_ddp=False`` the second dot-separated segment (which must be
       ``'module'``) is dropped from every key; with ``used_ddp=True`` the
       dict is returned unchanged.
    2. Marker key ``'_logit_scale.logit_scale'`` is present.
       With ``used_ddp=True`` a ``'module'`` segment is inserted after the
       first segment of every key; with ``used_ddp=False`` the dict is
       returned unchanged.
    3. Marker key ``'logit_scale'`` is present (flat layout), possibly after
       removing a leading ``'module.'`` from keys. ``visual.*`` keys go under
       the image-encoder prefix, ``logit_scale`` under the logit-scale prefix,
       and every other key under the text-encoder prefix. When ``used_ddp``
       is True each prefix gets an extra ``'module.'`` segment.

    Any other dict is returned unchanged.

    Args:
        state_dict: Mapping from parameter name to value. Values are passed
            through untouched (presumably tensors — not inspected here).
        used_ddp: Whether the produced keys should carry a ``module`` segment
            after the sub-module name (the name suggests a DDP-wrapped
            target model — confirm against callers).

    Returns:
        A new dict with converted keys, or ``state_dict`` itself when no
        conversion applies.

    Raises:
        ValueError: In layout 1 with ``used_ddp=False``, if some key does not
            have ``'module'`` as its second dot-separated segment.
    """
    module_prefix = 'module.'

    # Layout 1: keys already carry a 'module' segment after the sub-module name.
    if '_logit_scale.module.logit_scale' in state_dict:
        if used_ddp:
            return state_dict
        new_checkpoint = {}
        for k, v in state_dict.items():
            sp = k.split('.')
            # Explicit check (not ``assert``) so it is not stripped by ``python -O``;
            # also covers keys without any '.' that would otherwise hit IndexError.
            if len(sp) < 2 or sp[1] != 'module':
                raise ValueError(
                    f"expected 'module' as the second key segment, got {sp}; "
                    f"keys: {list(state_dict.keys())}"
                )
            new_checkpoint['.'.join(sp[:1] + sp[2:])] = v
        return new_checkpoint

    # Layout 2: new key layout without the 'module' segment.
    if '_logit_scale.logit_scale' in state_dict:
        if not used_ddp:
            return state_dict
        new_checkpoint = {}
        for k, v in state_dict.items():
            sp = k.split('.')
            new_checkpoint['.'.join(sp[:1] + ['module'] + sp[1:])] = v
        return new_checkpoint

    # Layout 3: flat (old CLIP) keys, possibly saved with a leading 'module.'.
    image_prefix = '_image_encoder.'
    text_prefix = '_text_encoder.'
    logit_scale_prefix = '_logit_scale.'
    if used_ddp:
        image_prefix += module_prefix
        text_prefix += module_prefix
        logit_scale_prefix += module_prefix

    if 'module.logit_scale' in state_dict:
        # Remove the leading 'module.' only where it is actually present;
        # blindly slicing every key would corrupt keys that lack it.
        state_dict = {
            (k[len(module_prefix):] if k.startswith(module_prefix) else k): v
            for k, v in state_dict.items()
        }

    if 'logit_scale' not in state_dict:
        # Unrecognized layout: nothing to convert.
        return state_dict

    new_checkpoint = {}
    for k, v in state_dict.items():
        if k.startswith('visual.'):
            new_checkpoint[image_prefix + k] = v
        elif k == 'logit_scale':
            new_checkpoint[logit_scale_prefix + 'logit_scale'] = v
        else:
            new_checkpoint[text_prefix + k] = v
    return new_checkpoint
1158
1159
1160def convert_weights_to_fp16(model: nn.Module):

Callers 3

remove_prefix_module — Function · 0.90
remove_prefix_module — Function · 0.90
load_state_dict — Method · 0.85

Calls

no outgoing calls

Tested by 1

remove_prefix_module — Function · 0.72