MCPcopy
hub / github.com/jindongwang/transferlearning / ClipModel

Class ClipModel

code/clip/clip_model.py:14–160  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

12
13
14class ClipModel(object):
15
# Model names selectable by integer position (index_to_model and
# get_model_name_by_index index into this list), so the order is significant.
CLIP_MODELS = [
    'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
    'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px',
]
27
def __init__(self, model_name='ViT-B/32', device='cuda', logger=None):
    """Load a CLIP model and its preprocessing transform with ``clip.load``.

    Args:
        model_name: either a model name handed to ``clip.load`` (for example
            one of ``CLIP_MODELS``) or an ``int`` index into ``CLIP_MODELS``.
        device: device passed to ``clip.load`` and to ``self.model.to``;
            also stored on ``self.device`` for later tensor moves.
        logger: stored on ``self.logger``; not used by this method.

    Sets ``self.model`` (in eval mode), ``self.preprocess``, ``self.device``,
    ``self.logger`` and ``self.model_name`` (the resolved name string).
    """
    # NOTE: the default must be spelled exactly like the CLIP_MODELS entry
    # ('ViT-B/32'); clip.load matches names exactly, so a mis-cased default
    # such as 'Vit-B/32' fails to resolve.
    self.device = device
    self.logger = logger
    if isinstance(model_name, int):
        # Integer selects a model by its position in CLIP_MODELS.
        model_name = self.index_to_model(model_name)
    self.model, self.preprocess = clip.load(
        model_name, device=device, jit=False)
    self.model.eval()
    self.model.to(device)
    self.model_name = model_name
38
def index_to_model(self, index):
    """Return the entry of ``CLIP_MODELS`` at position ``index``.

    Plain list indexing is used, so negative indices count from the end and
    an out-of-range index raises ``IndexError``.
    """
    available = self.CLIP_MODELS
    return available[index]
41
@staticmethod
def get_model_name_by_index(index):
    """Return ``ClipModel.CLIP_MODELS[index]`` with every '/' replaced by '_'.

    Example: the entry 'ViT-B/32' becomes 'ViT-B_32'.
    """
    return ClipModel.CLIP_MODELS[index].replace('/', '_')
47
def get_image_features(self, image, need_preprocess=False):
    """Encode ``image`` with ``self.model.encode_image`` under ``torch.no_grad()``.

    Args:
        image: input for ``encode_image``. When ``need_preprocess`` is true it
            is first passed through ``self.preprocess``, given a leading
            size-1 dimension via ``unsqueeze(0)`` and moved to
            ``self.device``. Otherwise it is used as given, with no device move.
        need_preprocess: whether to apply the preprocessing described above.

    Returns:
        The value returned by ``self.model.encode_image``.
    """
    if not need_preprocess:
        batch = image
    else:
        batch = self.preprocess(image)
        batch = batch.unsqueeze(0).to(self.device)
    with torch.no_grad():
        features = self.model.encode_image(batch)
    return features
54
def get_text_feature(self, text):
    """Tokenize ``text`` with ``clip.tokenize`` and encode it with the model.

    The tokens are moved to ``self.device`` and ``self.model.encode_text``
    runs under ``torch.no_grad()``; its result is returned unchanged.
    """
    tokens = clip.tokenize(text)
    tokens = tokens.to(self.device)
    with torch.no_grad():
        encoded = self.model.encode_text(tokens)
    return encoded
60
61 def get_text_features_list(self, texts, train=False):
62 if train:
63 text_inputs = torch.cat([clip.tokenize(c)
64 for c in texts]).to(self.device)
65 text_features = self.model.encode_text(text_inputs)
66 else:
67 with torch.no_grad():
68 text_inputs = torch.cat([clip.tokenize(c)
69 for c in texts]).to(self.device)
70 text_features = self.model.encode_text(text_inputs)
71

Callers 1

mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected