(args, model, llm_tokenizer, image_features, tokens, num_new_token, raw_texts, raw_alt_texts, prefix_length, device)
| 191 | |
| 192 | |
def gen_sample0(args, model, llm_tokenizer, image_features, tokens, num_new_token, raw_texts, raw_alt_texts, prefix_length, device):
    """Sample text from the model for a batch and log the first example.

    The image features are projected with ``model.clip_project`` and reshaped
    to ``(batch_size, -1, model.gpt_embedding_size)``.  When ``args`` has a
    ``rewrite_prompt`` attribute, the embeddings of the first
    ``args.rewrite_prompt`` columns of ``tokens`` are appended along dim 1.
    The result is passed to ``model.gpt.generate`` as ``inputs_embeds`` and the
    generated ids are decoded with ``llm_decode``.

    Args:
        args: Run configuration; only the optional ``rewrite_prompt``
            attribute is read here.
        model: Model (possibly wrapped; it is passed through
            ``unwrap_model``) exposing ``clip_project``,
            ``gpt_embedding_size`` and ``gpt``.
        llm_tokenizer: Tokenizer forwarded to ``llm_decode``.
        image_features: Input to ``model.clip_project``.
        tokens: Token-id tensor; ``tokens.size(0)`` is used as the batch size.
        num_new_token: Passed to ``generate`` as ``max_new_tokens``.
        raw_texts: Reference texts; must have one entry per generated string.
        raw_alt_texts: Alternative texts; only the first entry is logged.
        prefix_length: Not used by this function.
        device: Not used by this function.

    Raises:
        AssertionError: If the number of decoded strings differs from
            ``len(raw_texts)``.
    """
    model = unwrap_model(model)
    batch_size = tokens.size(0)

    # Nothing computed here is used for a loss; the outputs are only decoded
    # and logged, so skip autograd bookkeeping for the projection/embedding.
    with torch.no_grad():
        embedding_image = model.clip_project(image_features).reshape(
            batch_size, -1, model.gpt_embedding_size
        )

        # add rewrite_prompt and can be integrated into interactive_*.py
        # NOTE(review): hasattr() is also true when the attribute exists but is
        # None, in which case the slice below takes *all* of `tokens` --
        # confirm that is the intended behaviour.
        if hasattr(args, "rewrite_prompt"):
            embedding_text = model.gpt.get_input_embeddings()(tokens[:, :args.rewrite_prompt])
            embedding_cat = torch.cat((embedding_image, embedding_text), dim=1)
        else:
            embedding_cat = embedding_image

        gen_ids = model.gpt.generate(
            inputs_embeds=embedding_cat,
            max_new_tokens=num_new_token,
            temperature=0.2,
            do_sample=True,
            top_p=0.7,
        )

    # Raw ids are only useful when debugging; keep them out of stdout and let
    # the logging level decide whether the tensor is formatted at all.
    logging.debug("generated token ids: %s", gen_ids)

    gen_strs = llm_decode(llm_tokenizer, gen_ids, remove_new_line=True)
    assert len(gen_strs) == len(raw_texts), (
        f"decoded {len(gen_strs)} strings for {len(raw_texts)} reference texts"
    )

    logging.info(f"[alt]{raw_alt_texts[0]}")
    logging.info(f"[gt]{raw_texts[0]}")
    logging.info(f"[gen]{gen_strs[0]}")
no test coverage detected