MCPcopy
hub / github.com/zai-org/CogView / generate_images_once

Function generate_images_once

generate_samples.py:143–200  ·  view source on GitHub ↗
(model, args, raw_text, seq=None, num=8, query_template='{}', output_path='./samples')

Source from the content-addressed store, hash-verified

141
142
def generate_images_once(model, args, raw_text, seq=None, num=8, query_template='{}', output_path='./samples'):
    """Generate samples for ``raw_text`` once, then print or save the results.

    Args:
        model: Model handed to ``filling_sequence``; switched to eval mode.
        args: Namespace providing ``generation_task``,
            ``max_inference_batch_size`` and ``debug``.
        raw_text: Input query; echoed to stdout and, in debug mode, used
            (with ``/`` removed, truncated to 20 characters) as the output
            file prefix.
        seq: Already-tokenized input sequence. When ``None`` it is built from
            ``raw_text`` with ``_parse_and_to_tensor``.
        num: Number of samples requested. Must be smaller than
            ``args.max_inference_batch_size`` or an exact multiple of it.
        query_template: Template forwarded to ``_parse_and_to_tensor`` when
            ``seq`` has to be parsed.
        output_path: Directory receiving the images; created if missing.

    Returns:
        None. For ``'image2text'`` the decoded texts are printed and nothing
        is written. Otherwise images are saved: one combined, timestamped
        ``.jpg`` in debug mode, or ``<i>.jpg`` per sample plus ``concat.jpg``.

    Raises:
        NotImplementedError: If ``args.generation_task`` is not one of
            ``'text2image'``, ``'low-level super-resolution'``, ``'image2text'``.
        ValueError: If ``num`` is neither below the maximum inference batch
            size nor a multiple of it.
    """
    # Validate up front so a bad request fails before any directory is
    # created, the input sequence is mutated, or the model is run.
    supported_tasks = ('text2image', 'low-level super-resolution', 'image2text')
    if args.generation_task not in supported_tasks:
        raise NotImplementedError(
            f'unsupported generation task: {args.generation_task!r}')

    mbz = args.max_inference_batch_size
    if not (num < mbz or num % mbz == 0):
        raise ValueError(
            f'num ({num}) must be smaller than, or a multiple of, '
            f'max_inference_batch_size ({mbz})')

    tokenizer = get_tokenizer()
    os.makedirs(output_path, exist_ok=True)
    if seq is None:  # need parse
        img_size = 128 if args.generation_task == 'low-level super-resolution' else 256
        seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template)

    model.eval()
    with torch.no_grad():
        print('show raw text:', raw_text)
        start_time = time.time()

        add_interlacing_beam_marks(seq, nb=min(num, mbz))
        # One filling_sequence call per inference batch; always at least one.
        output_batches = []
        for _ in range(max(num // mbz, 1)):
            output_batches.append(filling_sequence(model, seq.clone(), args))
            torch.cuda.empty_cache()
        output_tokens_list = torch.cat(output_batches, dim=0)

        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)

        imgs, txts = [], []
        for output_tokens in output_tokens_list:
            decoded_txts, decoded_imgs = tokenizer.DecodeIds(output_tokens.tolist())
            for i, img in enumerate(decoded_imgs):
                # Bring 128-wide images up to 256x256 so every image has the
                # same size when concatenated below.
                if img.shape[-1] == 128:
                    decoded_imgs[i] = torch.nn.functional.interpolate(img, size=(256, 256))
            if args.debug:
                imgs.extend(decoded_imgs)
            else:
                imgs.append(decoded_imgs[-1])  # only the last image (target)
            txts.append(decoded_txts)

        if args.generation_task == 'image2text':
            print(txts)
            return

        if args.debug:
            output_file_prefix = raw_text.replace('/', '')[:20]
            timestamp = datetime.now().strftime('%m-%d-%H-%M-%S')
            output_file = os.path.join(output_path, f"{output_file_prefix}-{timestamp}.jpg")
            imgs = torch.cat(imgs, dim=0)
            print(txts)
            print("\nSave to: ", output_file, flush=True)
            save_image(imgs, output_file, normalize=True)
        else:
            print("\nSave to: ", output_path, flush=True)
            # Saved files are made readable/writable/executable by everyone.
            all_permissions = stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU
            for i, img in enumerate(imgs):
                sample_file = os.path.join(output_path, f'{i}.jpg')
                save_image(img, sample_file, normalize=True)
                os.chmod(sample_file, all_permissions)
            concat_file = os.path.join(output_path, 'concat.jpg')
            save_image(torch.cat(imgs, dim=0), concat_file, normalize=True)
            os.chmod(concat_file, all_permissions)

Callers 1

Calls 5

get_tokenizer — Function · 0.90
filling_sequence — Function · 0.90
_parse_and_to_tensor — Function · 0.85
DecodeIds — Method · 0.45

Tested by

no test coverage detected