(model, args, raw_text, seq, output_path="./samples")
| 221 | generate_images_once(model, args, raw_text, seq, num=args.batch_size, output_path=output_path) |
| 222 | |
def super_resolution(model, args, raw_text, seq, output_path="./samples"):
    """Run ``magnify`` on the trailing image tokens of ``seq`` and save the results.

    The last ``32 ** 2`` entries of ``seq`` (the span the debug branch decodes
    as an image) and the preceding prefix are passed to ``magnify``. Every
    token sequence it returns is decoded with the tokenizer, and all decoded
    images are concatenated and written to a single ``.jpg`` in
    ``output_path``.

    Args:
        model: Model forwarded to ``magnify``; switched to eval mode here.
        args: Run configuration forwarded to ``magnify``. When ``args.debug``
            is truthy, the decoded input image (resized to 512x512) is
            prepended to the saved images.
        raw_text: Prompt text. It is printed, and its first 20 characters
            (with ``/`` removed) form the output filename prefix.
        seq: Token sequence ending in the image tokens. Assumed to support
            negative slicing (list or 1-D tensor); confirm against the caller.
        output_path: Directory for the output image. Created if missing.

    Raises:
        RuntimeError: If there is nothing to save, i.e. ``magnify`` returned
            no sequences and ``args.debug`` is off.
    """
    # Hoisted so the image/prefix slices (and the debug decode) can't drift apart.
    num_image_tokens = 32 ** 2
    image_tokens = seq[-num_image_tokens:]
    prefix_tokens = seq[:-num_image_tokens]

    tokenizer = get_tokenizer()
    model.eval()
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(output_path, exist_ok=True)

    # Decoding also runs under no_grad so no autograd graph is kept for it.
    with torch.no_grad():
        start_time = time.time()
        output_tokens_list = magnify(model, tokenizer, image_tokens, prefix_tokens, args)
        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)

        output_file_prefix = raw_text.replace('/', '')[:20]
        timestamp = datetime.now().strftime('%m-%d-%H-%M-%S')
        output_file = os.path.join(output_path, f"{output_file_prefix}-{timestamp}.jpg")

        imgs = []
        if args.debug:
            # Show the low-resolution input next to the outputs for comparison.
            imgs.append(torch.nn.functional.interpolate(
                tokenizer.img_tokenizer.DecodeIds(image_tokens), size=(512, 512)))
        # Distinct loop name: the original reused ``seq`` and shadowed the argument.
        for out_tokens in output_tokens_list:
            _, decoded_imgs = tokenizer.DecodeIds(out_tokens.tolist())
            imgs.extend(decoded_imgs)

        if not imgs:
            raise RuntimeError(
                "super_resolution: no images to save (magnify returned no sequences)")
        imgs = torch.cat(imgs, dim=0)

    print("\nSave to: ", output_file, flush=True)
    save_image(imgs, output_file, normalize=True)
| 245 | |
| 246 | def post_selection(model, args, raw_text, seq, output_path): |
| 247 | tokenizer = get_tokenizer() |
no test coverage detected