| 12 | |
| 13 | |
| 14 | class ClipModel(object): |
| 15 | |
| 16 | CLIP_MODELS = [ |
| 17 | 'RN50', |
| 18 | 'RN101', |
| 19 | 'RN50x4', |
| 20 | 'RN50x16', |
| 21 | 'RN50x64', |
| 22 | 'ViT-B/32', |
| 23 | 'ViT-B/16', |
| 24 | 'ViT-L/14', |
| 25 | 'ViT-L/14@336px' |
| 26 | ] |
| 27 | |
| 28 | def __init__(self, model_name='Vit-B/32', device='cuda', logger=None): |
| 29 | self.device = device |
| 30 | self.logger = logger |
| 31 | if type(model_name) is int: |
| 32 | model_name = self.index_to_model(model_name) |
| 33 | self.model, self.preprocess = clip.load( |
| 34 | model_name, device=device, jit=False) |
| 35 | self.model.eval() |
| 36 | self.model.to(device) |
| 37 | self.model_name = model_name |
| 38 | |
def index_to_model(self, index):
    """Return the model name stored at position ``index`` of ``CLIP_MODELS``."""
    known_models = self.CLIP_MODELS
    return known_models[index]
| 41 | |
| 42 | @staticmethod |
| 43 | def get_model_name_by_index(index): |
| 44 | name = ClipModel.CLIP_MODELS[index] |
| 45 | name = name.replace('/', '_') |
| 46 | return name |
| 47 | |
def get_image_features(self, image, need_preprocess=False):
    """Encode ``image`` with the model's image encoder, without gradients.

    When ``need_preprocess`` is true, ``image`` is first passed through
    ``self.preprocess``, given a leading dimension of size one, and moved
    to ``self.device``. Otherwise it is handed to the encoder as given.
    """
    encoder_input = image
    if need_preprocess:
        prepared = self.preprocess(image)
        encoder_input = prepared.unsqueeze(0).to(self.device)
    with torch.no_grad():
        return self.model.encode_image(encoder_input)
| 54 | |
def get_text_feature(self, text):
    """Tokenize ``text`` via ``clip.tokenize`` and encode it without gradients.

    The tokens are moved to ``self.device`` before being passed to the
    model's text encoder.
    """
    tokens = clip.tokenize(text).to(self.device)
    with torch.no_grad():
        return self.model.encode_text(tokens)
| 60 | |
| 61 | def get_text_features_list(self, texts, train=False): |
| 62 | if train: |
| 63 | text_inputs = torch.cat([clip.tokenize(c) |
| 64 | for c in texts]).to(self.device) |
| 65 | text_features = self.model.encode_text(text_inputs) |
| 66 | else: |
| 67 | with torch.no_grad(): |
| 68 | text_inputs = torch.cat([clip.tokenize(c) |
| 69 | for c in texts]).to(self.device) |
| 70 | text_features = self.model.encode_text(text_inputs) |
| 71 | |