MCPcopy
hub / github.com/facebookresearch/MetaCLIP / inference

Function inference

altogether/infer.py:148–194  ·  view source on GitHub ↗
(args, batch_iter, model, clip_model, tokenizer)

Source from the content-addressed store, hash-verified

146
147
def inference(args, batch_iter, model, clip_model, tokenizer):
    """Generate an "altogether" caption for every sample yielded by ``batch_iter``.

    For each batch, this function does the following:

    1. Moves the batch to ``args.clipcap_args["device"]`` via ``to_device``.
    2. Unpacks it into ``(imgs, prompt_input_ids, uuids)``.
    3. Encodes the images with ``clip_model.encode_image`` and L2-normalises
       them along the last dim.
    4. Projects them with ``model.clip_project`` and reshapes them into
       ``prefix_length`` embeddings per image.
    5. Prepends those embeddings to the embedded prompt tokens.
    6. Samples up to ``args.max_seq_len - args.rewrite_prompt`` new tokens with
       ``model.gpt.generate`` (temperature 0.2, top-p 0.7).
    7. Truncates the decoded text at the first ``tokenizer.eos_token``.

    A throughput line (``qps: ...``) is printed once all batches are consumed.

    Args:
        args: Run configuration. Reads ``clipcap_args["device"]``,
            ``clipcap_args["prefix_length"]``, ``max_seq_len`` and
            ``rewrite_prompt``.
        batch_iter: Iterable (or iterator) of batches accepted by ``to_device``.
        model: Model exposing ``clip_project`` and a ``gpt`` generator.
        clip_model: Model exposing ``encode_image``.
        tokenizer: Tokenizer passed to ``llm_decode``. Its ``eos_token``, when
            set, marks the end of a caption.

    Returns:
        dict mapping each uuid to ``{"altogether": caption}``.

    Raises:
        AttributeError: if ``args`` has no ``rewrite_prompt`` attribute.
    """
    import time

    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not hasattr(args, "rewrite_prompt"):
        raise AttributeError("args.rewrite_prompt must be set for inference")

    # Loop-invariant configuration: read once instead of once per batch.
    device = args.clipcap_args["device"]
    prefix_length = args.clipcap_args["prefix_length"]
    num_new_token = args.max_seq_len - args.rewrite_prompt
    eos_token = tokenizer.eos_token

    outputs = {}
    total_size = 0
    # Monotonic clock, so the throughput figure is immune to wall-clock changes.
    t0 = time.perf_counter()

    # NOTE(review): torch.cuda.amp.autocast() is deprecated in recent PyTorch in
    # favour of torch.amp.autocast("cuda"); kept for compatibility with older versions.
    with torch.no_grad(), torch.cuda.amp.autocast():
        for batch in batch_iter:
            imgs, prompt_input_ids, uuids = to_device(batch, device)
            batch_size = imgs.size(0)

            image_features = F.normalize(clip_model.encode_image(imgs), dim=-1)
            # Reshape the projected features into `prefix_length` embeddings per image.
            embedding_image = model.clip_project(image_features).view(
                batch_size, prefix_length, -1
            )
            embedding_text = model.gpt.get_input_embeddings()(prompt_input_ids)
            # Image prefix first, then the prompt tokens, along dim=1.
            embedding_cat = torch.cat((embedding_image, embedding_text), dim=1)

            gen_ids = model.gpt.generate(
                inputs_embeds=embedding_cat,
                max_new_tokens=num_new_token,
                temperature=0.2,
                do_sample=True,
                top_p=0.7,
                use_cache=True,
            )

            cap_strs = llm_decode(tokenizer, gen_ids, remove_new_line=True)
            # Guard: str.split(None) would split on whitespace (keeping only the
            # first word), and str.split("") raises ValueError.
            if eos_token:
                cap_strs = [cap.split(eos_token)[0] for cap in cap_strs]

            for sample_id, cap_str in zip(uuids, cap_strs):
                outputs[sample_id] = {"altogether": cap_str}
            total_size += batch_size

    elapsed = time.perf_counter() - t0
    # Avoid ZeroDivisionError when the iterator is empty and the clock did not advance.
    qps = total_size / elapsed if elapsed > 0 else 0.0
    print(f"qps: {qps}")
    return outputs
195
196
197def main(config_name, checkpoint_name, batch_size, data_path, cap_path, todo):

Callers 1

main — Function · 0.85

Calls 3

to_device — Function · 0.90
llm_decode — Function · 0.90
encode_image — Method · 0.80

Tested by

no test coverage detected