| 1071 | |
| 1072 | |
class CLIP(CLIPBase):
    """CLIP model built from an ``ImageEncoder`` and a ``TextEncoder``.

    Both encoders receive the same ``embed_dim``, ``quick_gelu`` flag and a
    shared ``mask_cfg`` namespace holding the sparsity settings. The built
    encoders are handed to ``CLIPBase.__init__``.
    """

    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        mask_image: bool = False,
        mask_text: bool = False,
        sparsity_warmup: int = 1000,
        sparsity: float = 0.25,
        start_sparsity: float = 0.0,
    ):
        """Build the image and text encoders and initialise ``CLIPBase``.

        Args:
            embed_dim: Embedding size forwarded to both encoders.
            vision_cfg: A ``CLIPVisionCfg`` instance, or a dict of keyword
                arguments for ``CLIPVisionCfg``. The dict is not modified.
            text_cfg: A ``CLIPTextCfg`` instance, or a dict of keyword
                arguments for ``CLIPTextCfg``. The dict is not modified.
            quick_gelu: Forwarded to both encoders.
            mask_image: Forwarded to ``ImageEncoder`` as ``l0_module_image``.
            mask_text: Forwarded to ``TextEncoder`` as ``l0_module_text``.
            sparsity_warmup: Stored on the shared ``mask_cfg``.
            sparsity: Stored on the shared ``mask_cfg``.
            start_sparsity: Stored on the shared ``mask_cfg``.

        Raises:
            NotImplementedError: If a dict config carries a ``'configs'``
                entry. This constructor builds no encoder for that case.
                Previously this surfaced as an ``UnboundLocalError`` at the
                ``super().__init__`` call.
        """
        vision_ocfg = None
        text_ocfg = None

        if isinstance(vision_cfg, dict):
            # Pop from a shallow copy so the caller's dict keeps 'configs'
            # and can be reused for another model.
            vision_cfg = dict(vision_cfg)
            vision_ocfg = vision_cfg.pop('configs', None)
            vision_cfg = CLIPVisionCfg(**vision_cfg)

        if isinstance(text_cfg, dict):
            text_cfg = dict(text_cfg)
            text_ocfg = text_cfg.pop('configs', None)
            text_cfg = CLIPTextCfg(**text_cfg)

        # A 'configs' entry leaves no encoder to build here. Fail loudly and
        # early, before any encoder is constructed, instead of with an
        # unbound local at the super() call.
        if vision_ocfg is not None:
            raise NotImplementedError(
                "CLIP: vision_cfg['configs'] is not supported by this "
                "constructor; no image encoder would be built."
            )
        if text_ocfg is not None:
            raise NotImplementedError(
                "CLIP: text_cfg['configs'] is not supported by this "
                "constructor; no text encoder would be built."
            )

        # One namespace is shared by both encoders, so they see the same
        # sparsity schedule settings.
        mask_cfg = Namespace()
        mask_cfg.sparsity_warmup = sparsity_warmup
        mask_cfg.sparsity = sparsity
        mask_cfg.start_sparsity = start_sparsity

        image_encoder = ImageEncoder(embed_dim, vision_cfg, quick_gelu,
                                     l0_module_image=mask_image,
                                     mask_cfg=mask_cfg)
        text_encoder = TextEncoder(embed_dim, text_cfg, quick_gelu,
                                   l0_module_text=mask_text,
                                   mask_cfg=mask_cfg)

        super().__init__(image_encoder, text_encoder)
| 1113 | |
| 1114 | |
| 1115 | def convert_to_new_checkpoint(state_dict, used_ddp=False): |
no outgoing calls
no test coverage detected