Repository: github.com/facebookresearch/MetaCLIP

Function gen_sample0

src/training/train_altogether.py:193–221
(args, model, llm_tokenizer, image_features, tokens, num_new_token, raw_texts, raw_alt_texts, prefix_length, device)
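The signature does not state tensor shapes. The sketch below reproduces the shape arithmetic of the body, assuming `clip_project` maps each image to a flat vector of width `projected_dim`. `input_embeds_shape`, `projected_dim` and `prompt_tokens` are hypothetical names introduced here. In the body, `reshape(batch_size, -1, gpt_embedding_size)` infers the number of image prefix positions, and the optional rewrite prompt appends more positions along dim=1. Because `tokens[:, :args.rewrite_prompt]` keeps at most `tokens.size(1)` positions, `prompt_tokens` should be that clipped count.

```python
from typing import Tuple


def input_embeds_shape(
    batch_size: int,
    projected_dim: int,
    gpt_embedding_size: int,
    prompt_tokens: int = 0,
) -> Tuple[int, int, int]:
    """Hypothetical helper: shape of the tensor passed as generate(inputs_embeds=...)."""
    # reshape(batch_size, -1, E) only succeeds if each projected vector splits evenly into E-sized chunks.
    if projected_dim % gpt_embedding_size != 0:
        raise ValueError("projected_dim must be a multiple of gpt_embedding_size")
    # The -1 in reshape resolves to the number of image prefix positions.
    prefix_positions = projected_dim // gpt_embedding_size
    # torch.cat(..., dim=1) appends the rewrite-prompt embeddings after the image prefix.
    return (batch_size, prefix_positions + prompt_tokens, gpt_embedding_size)


# Example: a 10-position image prefix in a 768-wide embedding space plus a 5-token prompt.
print(input_embeds_shape(4, 10 * 768, 768, prompt_tokens=5))  # (4, 15, 768)
```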

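import logging

import torch

# unwrap_model and llm_decode are helpers from the MetaCLIP training code (not shown here).


def gen_sample0(args, model, llm_tokenizer, image_features, tokens, num_new_token, raw_texts, raw_alt_texts, prefix_length, device):
    # Note: prefix_length and device are accepted but not used in this body.
    # Unwrap a wrapped (e.g. DistributedDataParallel) model so its submodules are directly accessible.
    model = unwrap_model(model)
    batch_size = tokens.size(0)

    # Project CLIP image features into the LLM embedding space as a sequence of prefix embeddings:
    # shape (batch_size, num_prefix_positions, gpt_embedding_size).
    embedding_image = model.clip_project(image_features).reshape(batch_size, -1, model.gpt_embedding_size)

    # Optionally append a rewrite prompt (the embeddings of the first args.rewrite_prompt
    # positions of `tokens`) after the image prefix; this path can be integrated into interactive_*.py.
    if hasattr(args, "rewrite_prompt"):
        embedding_text = model.gpt.get_input_embeddings()(tokens[:, :args.rewrite_prompt])
        embedding_cat = torch.cat((embedding_image, embedding_text), dim=1)
    else:
        embedding_cat = embedding_image

    # Sample up to num_new_token tokens conditioned on the prefix embeddings,
    # using a low temperature combined with nucleus (top-p) sampling.
    gen_ids = model.gpt.generate(
        inputs_embeds=embedding_cat,
        max_new_tokens=num_new_token,
        temperature=0.2,
        do_sample=True,
        top_p=0.7,
    )

    print(gen_ids)  # debug output: raw generated token ids

    # Decode to strings with newlines removed; expect one generation per ground-truth caption.
    gen_strs = llm_decode(llm_tokenizer, gen_ids, remove_new_line=True)
    assert len(gen_strs) == len(raw_texts)

    # Log the first sample of the batch: alt text, ground-truth caption, and generated caption.
    logging.info(f"[alt]{raw_alt_texts[0]}")
    logging.info(f"[gt]{raw_texts[0]}")
    logging.info(f"[gen]{gen_strs[0]}")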
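The `generate` call samples with `do_sample=True`, `temperature=0.2` and `top_p=0.7`. The sketch below is a plain-Python illustration of standard temperature scaling followed by nucleus (top-p) filtering for a single decoding step. It is not the transformers implementation, and `sample_top_p` is a hypothetical name.

```python
import math
import random
from typing import Optional, Sequence


def sample_top_p(
    logits: Sequence[float],
    temperature: float = 0.2,
    top_p: float = 0.7,
    rng: Optional[random.Random] = None,
) -> int:
    """Draw one token id using temperature scaling and nucleus (top-p) filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    rng = rng or random.Random()
    # Temperature scaling: dividing by a small temperature sharpens the distribution.
    scaled = [x / temperature for x in logits]
    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank token ids from most to least probable (ties keep their original order).
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Keep the smallest high-probability set whose cumulative mass reaches top_p.
    nucleus, cumulative = [], 0.0
    for i in order:
        nucleus.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # Sample within the nucleus; random.choices normalizes the kept weights itself.
    return rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]


print(sample_top_p([2.0, 1.5, 0.1, -1.0], rng=random.Random(0)))
```

At temperature 0.2 the logits are effectively multiplied by 5 before the softmax, so the distribution is sharply peaked and the 0.7 nucleus often holds only a few tokens: generations stay close to greedy decoding while still varying between calls.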

Callers (1)

train_altogether · Function · 0.85

Calls (2)

unwrap_model · Function · 0.90
llm_decode · Function · 0.85
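Neither callee is shown on this page. As a hedged sketch: `unwrap_model` below follows the common open_clip-style pattern of returning `model.module` when the model is wrapped (for example by DistributedDataParallel); MetaCLIP's own definition may differ. `decode_generations` is a hypothetical stand-in for `llm_decode`, assuming it batch-decodes with a Hugging Face-style tokenizer and, when `remove_new_line=True`, collapses newlines into single spaces. `ToyTokenizer` is an illustrative stub, not a real tokenizer.

```python
from typing import Any, List, Sequence


def unwrap_model(model: Any) -> Any:
    # Wrappers such as DistributedDataParallel keep the underlying network in `.module`.
    return model.module if hasattr(model, "module") else model


def decode_generations(tokenizer: Any, gen_ids: Sequence[Sequence[int]], remove_new_line: bool = True) -> List[str]:
    # Hypothetical stand-in for llm_decode: decode each id sequence, dropping special tokens.
    texts = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    if remove_new_line:
        # Assumption: newlines (and any other whitespace runs) collapse to single spaces.
        texts = [" ".join(t.split()) for t in texts]
    return texts


class ToyTokenizer:
    """Illustrative stub exposing a Hugging Face-style batch_decode method."""

    vocab = {0: "a", 1: " cat", 2: "\n", 3: "</s>"}
    special_ids = {3}

    def batch_decode(self, sequences, skip_special_tokens=False):
        return [
            "".join(self.vocab[i] for i in seq if not (skip_special_tokens and i in self.special_ids))
            for seq in sequences
        ]


print(decode_generations(ToyTokenizer(), [[0, 1, 2, 3], [0, 2, 1]]))  # ['a cat', 'a cat']
```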

Tested by

no test coverage detected