| 1047 | return self._logit_scale.named_parameters() |
| 1048 | |
def load_state_dict(self, state_dict, strict=True):
    """Adapt a checkpoint's key layout, then load it through the parent class.

    Steps, in order:

    1. Pass ``state_dict`` through ``convert_to_new_checkpoint`` together
       with ``self.used_ddp``.
       NOTE(review): that helper is defined elsewhere; presumably it
       upgrades legacy checkpoint layouts -- confirm.
    2. If no key starts with ``'_image_encoder'``, call
       ``self.use_teacher_image()``.
       NOTE(review): presumably this makes the model rely on a teacher
       image encoder instead of its own weights -- confirm.
    3. For both the ``'module.'``-prefixed and the unprefixed form of the
       image encoder, rename ``visual.model.head.0.{weight,bias}`` to
       ``visual.ln_post.{weight,bias}`` and store the transpose of
       ``visual.model.head.1.weight`` under ``visual.proj``.
    4. Remove every ``'.module'`` substring from the keys.
    5. Delegate to the parent class's ``load_state_dict``.

    The mapping produced in step 1 is modified in place; whether that is
    the very object the caller passed in depends on
    ``convert_to_new_checkpoint``.

    Args:
        state_dict: Mapping from parameter names to values.
        strict: Forwarded unchanged to the parent's ``load_state_dict``.

    Returns:
        Whatever the parent's ``load_state_dict`` returns. For a
        ``torch.nn.Module`` parent this is the result object exposing
        ``missing_keys`` and ``unexpected_keys``, which callers need when
        loading with ``strict=False``.

    Raises:
        KeyError: If ``visual.model.head.0.weight`` is present but the
            matching ``head.0.bias`` or ``head.1.weight`` entry is not.
    """
    state_dict = convert_to_new_checkpoint(state_dict, self.used_ddp)

    if not any(key.startswith('_image_encoder') for key in state_dict):
        self.use_teacher_image()

    for m in ('module.', ''):
        visual = f'_image_encoder.{m}visual'
        if f'{visual}.model.head.0.weight' not in state_dict:
            continue
        # LN
        state_dict[f'{visual}.ln_post.weight'] = state_dict.pop(
            f'{visual}.model.head.0.weight')
        state_dict[f'{visual}.ln_post.bias'] = state_dict.pop(
            f'{visual}.model.head.0.bias')
        # FC: stored transposed under `proj`.
        state_dict[f'{visual}.proj'] = state_dict.pop(
            f'{visual}.model.head.1.weight').T

    # Iterate over a snapshot because the loop inserts and removes keys.
    for key, value in list(state_dict.items()):
        if '.module' in key:
            state_dict[key.replace('.module', '')] = value
            state_dict.pop(key)

    # Propagate the parent's result instead of silently discarding it.
    return super().load_state_dict(state_dict, strict=strict)
| 1071 | |
| 1072 | |
| 1073 | class CLIP(CLIPBase): |