(args, batch_iter, model, clip_model, tokenizer)
| 146 | |
| 147 | |
def inference(args, batch_iter, model, clip_model, tokenizer):
    """Generate one caption per image for every batch from ``batch_iter``.

    For each batch the images are encoded with ``clip_model``, L2-normalized,
    projected by ``model.clip_project`` and reshaped to ``prefix_length``
    prefix embeddings per sample. These are concatenated in front of the
    prompt token embeddings and passed to ``model.gpt.generate`` with
    sampling enabled. Decoded text is truncated at the first
    ``tokenizer.eos_token``.

    The configuration is read once, before the first batch is consumed, so a
    misconfigured ``args`` fails immediately rather than mid-run.

    Args:
        args: Run configuration. Must provide ``max_seq_len``,
            ``rewrite_prompt`` and a ``clipcap_args`` mapping containing the
            keys ``"device"`` and ``"prefix_length"``.
        batch_iter: Iterable of batches. Each batch is passed through
            ``to_device`` and must unpack to
            ``(imgs, prompt_input_ids, uuids)``.
        model: Captioning model exposing ``clip_project`` and ``gpt``.
        clip_model: Image encoder exposing ``encode_image``.
        tokenizer: Tokenizer handed to ``llm_decode``; its ``eos_token``
            marks the end of each caption.

    Returns:
        dict: Maps each uuid to ``{"altogether": caption}``.

    Raises:
        AttributeError: If ``args`` has no ``rewrite_prompt`` attribute.
    """
    import time

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not hasattr(args, "rewrite_prompt"):
        raise AttributeError("args must define 'rewrite_prompt'")

    # Loop-invariant configuration, looked up once.
    clipcap_args = args.clipcap_args
    device = clipcap_args["device"]
    prefix_length = clipcap_args["prefix_length"]
    num_new_token = args.max_seq_len - args.rewrite_prompt
    # NOTE(review): clipcap_args["pad_token_id"] is not forwarded to
    # generate(); confirm whether it should be.

    outputs = {}
    total_size = 0
    start = time.perf_counter()

    with torch.no_grad(), torch.cuda.amp.autocast():
        for batch in batch_iter:
            imgs, prompt_input_ids, uuids = to_device(batch, device)
            batch_size = imgs.size(0)

            image_features = clip_model.encode_image(imgs)
            image_features = F.normalize(image_features, dim=-1)

            # One row of ``prefix_length`` embeddings per image.
            embedding_image = model.clip_project(image_features).view(
                batch_size, prefix_length, -1
            )
            embedding_text = model.gpt.get_input_embeddings()(prompt_input_ids)
            # Image prefix goes first, prompt embeddings follow.
            embedding_cat = torch.cat((embedding_image, embedding_text), dim=1)

            gen_ids = model.gpt.generate(
                inputs_embeds=embedding_cat,
                max_new_tokens=num_new_token,
                temperature=0.2,
                do_sample=True,
                top_p=0.7,
                use_cache=True,
            )

            cap_strs = llm_decode(tokenizer, gen_ids, remove_new_line=True)
            for sample_id, cap_str in zip(uuids, cap_strs):
                # Keep only the text before the first end-of-sequence marker.
                caption = cap_str.split(tokenizer.eos_token)[0]
                outputs[sample_id] = {"altogether": caption}
            total_size += batch_size

    elapsed = time.perf_counter() - start
    # Guard against a zero-length interval (e.g. an empty iterator).
    qps = total_size / elapsed if elapsed > 0 else 0.0
    print(f"qps: {qps}")
    return outputs
| 195 | |
| 196 | |
| 197 | def main(config_name, checkpoint_name, batch_size, data_path, cap_path, todo): |
no test coverage detected