hub / github.com/zai-org/CogView / super_resolution

Function super_resolution

generate_samples.py:223–244
(model, args, raw_text, seq, output_path="./samples")

Source from the content-addressed store, hash-verified

def super_resolution(model, args, raw_text, seq, output_path="./samples"):
    tokenizer = get_tokenizer()
    model.eval()
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    with torch.no_grad():
        start_time = time.time()
        output_tokens_list = magnify(model, tokenizer, seq[-32**2:], seq[:-32**2], args)

        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)
        output_file_prefix = raw_text.replace('/', '')[:20]
        output_file = os.path.join(output_path, f"{output_file_prefix}-{datetime.now().strftime('%m-%d-%H-%M-%S')}.jpg")
        imgs = []
        if args.debug:
            imgs.append(torch.nn.functional.interpolate(tokenizer.img_tokenizer.DecodeIds(seq[-32**2:]), size=(512, 512)))
        for seq in output_tokens_list:
            decoded_txts, decoded_imgs = tokenizer.DecodeIds(seq.tolist())
            imgs.extend(decoded_imgs)
        imgs = torch.cat(imgs, dim=0)
        print("\nSave to: ", output_file, flush=True)
        save_image(imgs, output_file, normalize=True)
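
The two slices passed to `magnify` rely on operator precedence: `-32**2` evaluates as `-(32**2)`, that is `-1024`. So `seq[-32**2:]` is the trailing 1024 tokens (a 32×32 grid, which the debug branch decodes with the image tokenizer) and `seq[:-32**2]` is everything before them, presumably the text prompt tokens. A minimal sketch of that split, using a plain Python list in place of the real token tensor; the `split_sequence` helper and the toy token values are illustrative assumptions, not part of the repository:

```python
GRID = 32  # side length of the low-resolution token grid, as hard-coded in the source


def split_sequence(seq):
    """Hypothetical helper: split `seq` the same way super_resolution slices it."""
    n_image = GRID ** 2               # 1024 trailing tokens
    image_tokens = seq[-n_image:]     # equivalent to seq[-32**2:]
    prefix_tokens = seq[:-n_image]    # equivalent to seq[:-32**2]
    return prefix_tokens, image_tokens


# Toy sequence: 5 made-up "text" tokens followed by 1024 made-up "image" tokens.
seq = list(range(5)) + list(range(1000, 1000 + 1024))
prefix_tokens, image_tokens = split_sequence(seq)

print(-32**2)                                  # unary minus binds looser than **
print(len(prefix_tokens), len(image_tokens))
```

One consequence of this slicing is that a sequence with exactly 1024 tokens yields an empty prefix, and a shorter sequence is passed through whole as the "image" part without raising an error.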

Callers 1

Calls 3

get_tokenizer · Function · 0.90
magnify · Function · 0.90
DecodeIds · Method · 0.45

Tested by

no test coverage detected
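
One part of `super_resolution` that could be checked without a model or GPU is the output filename logic. The sketch below assumes that logic were factored into a hypothetical helper, `make_output_file`, with an injectable timestamp; neither the helper nor the `now` parameter exists in the repository, and the expression itself is copied from the source above.

```python
import os
from datetime import datetime


def make_output_file(raw_text, output_path="./samples", now=None):
    """Hypothetical helper mirroring the filename expression in super_resolution."""
    now = now or datetime.now()
    # Only '/' is stripped; the prompt is then truncated to 20 characters.
    prefix = raw_text.replace('/', '')[:20]
    return os.path.join(output_path, f"{prefix}-{now.strftime('%m-%d-%H-%M-%S')}.jpg")


fixed = datetime(2021, 6, 1, 12, 30, 45)  # arbitrary fixed time for a repeatable check

short_name = os.path.basename(make_output_file("a/b", "out", now=fixed))
long_name = os.path.basename(make_output_file("x" * 30, "out", now=fixed))

assert short_name == "ab-06-01-12-30-45.jpg"
assert long_name == "x" * 20 + "-06-01-12-30-45.jpg"
```

Because the timestamp has one-second resolution and no year, two runs with the same prompt prefix in the same second (or at the same time in different years) would map to the same file name.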