MCPcopy
hub / github.com/showlab/Show-o / UniversalPrompting

Class UniversalPrompting

training/prompting_utils.py:18–464  ·  view source on GitHub ↗

Source retrieved from the content-addressed store (hash-verified).

16import torch
17# TODO - SHOULD BE FURTHER IMPROVED.
18class UniversalPrompting():
def __init__(self, text_tokenizer,
             special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>", "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
             max_text_len=8000, max_seq_len=377, ignore_id=-100, cond_dropout_prob=0.1):
    """
    Register the prompting special tokens on ``text_tokenizer`` and cache their ids.

    :param text_tokenizer: original text tokenizer. It is modified in place: a
        ``'[PAD]'`` pad token and every entry of ``special_tokens`` are added to it.
    :param special_tokens: iterable of special-token strings to add. Each one becomes
        a key of ``self.sptids_dict`` mapping to a 1-element id tensor. Any iterable,
        including a one-shot generator, is accepted.
    :param max_text_len: maximum text length; stored as ``max_text_len + 1`` to leave
        room for the task token (e.g. ``<|t2i|>``) that prompting methods prepend.
    :param max_seq_len: stored as ``self.max_seq_len``; not used by this constructor.
    :param ignore_id: stored as ``self.ignore_id`` (default -100); presumably the
        label id ignored by the loss -- confirm against the training code.
    :param cond_dropout_prob: stored as ``self.cond_dropout_prob``; the probability
        with which prompting methods drop the text condition.
    """
    # Materialize once: the tokens are iterated twice below (add_tokens and the
    # id lookup), and a one-shot iterator would be exhausted by the first pass,
    # leaving sptids_dict silently missing every special token.
    special_tokens = tuple(special_tokens)

    self.text_tokenizer = text_tokenizer
    self.text_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    self.text_tokenizer.add_tokens(list(special_tokens))

    # Ids are kept as 1-element tensors so they can be torch.cat'ed directly
    # into prompt sequences.
    self.sptids_dict = {
        token: torch.tensor(self.text_tokenizer.convert_tokens_to_ids([token]))
        for token in special_tokens
    }
    self.sptids_dict['<|sot|>'] = torch.tensor([self.text_tokenizer.bos_token_id])
    self.sptids_dict['<|eot|>'] = torch.tensor([self.text_tokenizer.eos_token_id])
    self.sptids_dict['<|pad|>'] = torch.tensor([self.text_tokenizer.pad_token_id])

    # +1 because the prompting methods prepend a task token in front of the
    # text tokens, so the text budget must grow by one position.
    self.max_text_len = max_text_len + 1
    # Previously accepted but discarded; kept so callers can read it back.
    self.max_seq_len = max_seq_len
    self.pad_id = self.text_tokenizer.convert_tokens_to_ids('[PAD]')
    self.ignore_id = ignore_id
    self.cond_dropout_prob = cond_dropout_prob
38
39 def t2i_prompt(self, text_ids, image_ids, labels):
40
41 device = image_ids.device
42 sequence_ids = []
43 attention_masks = []
44 label_ids = []
45 probs = torch.rand(len(text_ids))
46 for i in range(len(text_ids)):
47
48 if len(text_ids[i]) == 0:
49 text_ids[i] = [self.text_tokenizer.bos_token_id]
50 elif text_ids[i][0] != self.text_tokenizer.bos_token_id:
51 text_ids[i] = [self.text_tokenizer.bos_token_id] + text_ids[i]
52
53 temp_ids = [int(self.sptids_dict['<|t2i|>'])] + text_ids[i] + [self.text_tokenizer.eos_token_id]
54
55 # randomly dropout text condition
56 if probs[i] < self.cond_dropout_prob:
57 temp_ids = [int(self.sptids_dict['<|t2i|>']), self.text_tokenizer.bos_token_id, self.text_tokenizer.eos_token_id]
58
59 if self.max_text_len >= len(temp_ids):
60 temp_ids = [self.pad_id] * (self.max_text_len - len(temp_ids)) + temp_ids
61 temp_masks = [0] * (self.max_text_len - len(temp_ids)) + [1] * (len(temp_ids) + image_ids.shape[-1] + 3)
62 else:
63 # should add the eos token
64 temp_ids = temp_ids[:self.max_text_len - 1] + [self.text_tokenizer.eos_token_id]
65 temp_masks = [1] * (len(temp_ids) + image_ids.shape[-1] + 3) # +2 for two special tokens
66
67 # prompting -- [task token] [sot] [text tokens] [eot] [soi] [image tokens] [eoi]
68 temp_label_ids = torch.cat([
69 # should we predict text tokens when doing image reconstruction?
70 torch.tensor(temp_ids).to(device),
71 self.sptids_dict['<|soi|>'].to(device),
72 labels[i],
73 self.sptids_dict['<|eoi|>'].to(device)
74 ], dim=0)
75

Callers (4)

inference_mmu.py · File · 0.90
inference_t2i.py · File · 0.90
main · Function · 0.90
main · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected