(state_dict, used_ddp=False)
| 1113 | |
| 1114 | |
def convert_to_new_checkpoint(state_dict, used_ddp=False):
    """Normalise checkpoint key names to the ``_image_encoder`` /
    ``_text_encoder`` / ``_logit_scale`` layout.

    Three input layouts are recognised, detected by which ``logit_scale``
    key is present:

    * ``'_logit_scale.module.logit_scale'`` -- new layout with a ``module``
      component after the first dotted component of every key.
    * ``'_logit_scale.logit_scale'`` -- new layout without that component.
    * ``'logit_scale'`` or ``'module.logit_scale'`` -- old CLIP checkpoint;
      keys starting with ``'visual.'`` go under the image-encoder prefix,
      ``'logit_scale'`` under the logit-scale prefix, and everything else
      under the text-encoder prefix.

    Args:
        state_dict: Mapping from parameter name to value.
        used_ddp: When true, the returned keys carry a ``module`` component
            after the top-level prefix (presumably matching a model wrapped
            in DistributedDataParallel -- confirm against the caller); when
            false, they do not.

    Returns:
        A mapping with converted keys. If no conversion is required, or the
        layout is not recognised, the input mapping itself is returned.

    Raises:
        AssertionError: If a key's second dotted component is not
            ``'module'`` while stripping it from a new-layout checkpoint.
    """
    # New layout saved with the ``module`` component.
    if '_logit_scale.module.logit_scale' in state_dict:
        if used_ddp:
            return state_dict
        stripped = {}
        for key, value in state_dict.items():
            parts = key.split('.')
            # Sanity check: every key is expected to look like
            # ``<prefix>.module.<rest>``.
            assert parts[1] == 'module', (parts, state_dict.keys())
            stripped['.'.join(parts[:1] + parts[2:])] = value
        return stripped

    # New layout saved without the ``module`` component.
    if '_logit_scale.logit_scale' in state_dict:
        if not used_ddp:
            return state_dict
        wrapped = {}
        for key, value in state_dict.items():
            parts = key.split('.')
            wrapped['.'.join(parts[:1] + ['module'] + parts[1:])] = value
        return wrapped

    module_prefix = 'module.'
    if module_prefix + 'logit_scale' in state_dict:
        # Remove the leading ``module.`` prefix, but only from keys that
        # actually have it; slicing unconditionally would truncate the
        # first characters of any un-prefixed key.
        state_dict = {
            (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
            for key, value in state_dict.items()
        }

    if 'logit_scale' not in state_dict:
        # Unrecognised layout: hand the mapping back untouched.
        return state_dict

    # Old CLIP checkpoint: route every key under its new top-level prefix.
    infix = 'module.' if used_ddp else ''
    image_prefix = '_image_encoder.' + infix
    text_prefix = '_text_encoder.' + infix
    logit_scale_prefix = '_logit_scale.' + infix

    new_checkpoint = {}
    for key, value in state_dict.items():
        if key.startswith('visual.'):
            new_checkpoint[image_prefix + key] = value
        elif key == 'logit_scale':
            new_checkpoint[logit_scale_prefix + 'logit_scale'] = value
        else:
            new_checkpoint[text_prefix + key] = value
    return new_checkpoint
| 1158 | |
| 1159 | |
| 1160 | def convert_weights_to_fp16(model: nn.Module): |
no outgoing calls