(model, args, raw_text, seq=None, num=8, query_template='{}', output_path='./samples')
| 141 | |
| 142 | |
def generate_images_once(model, args, raw_text, seq=None, num=8, query_template='{}', output_path='./samples'):
    """Generate one set of samples for ``raw_text`` and print or save the results.

    The token sequence is completed by ``filling_sequence`` in one or more
    batches, the resulting tokens are decoded with the tokenizer, and then:

    * for ``'image2text'`` the decoded texts are printed and nothing is saved;
    * with ``args.debug`` set, every decoded image is concatenated into one
      timestamped ``.jpg`` under ``output_path``;
    * otherwise only the last decoded image of each sample is kept, saved as
      ``<index>.jpg``, plus a ``concat.jpg`` holding all of them.

    Args:
        model: Model used for generation; it is switched to eval mode.
        args: Namespace providing ``generation_task``,
            ``max_inference_batch_size`` and ``debug``; it is also forwarded
            to ``filling_sequence``.
        raw_text: Input text. Printed, parsed when ``seq`` is None, and used
            to build the file name in debug mode.
        seq: Pre-built token sequence. When None it is built from
            ``raw_text`` via ``_parse_and_to_tensor``.
        num: Number of samples requested. Must be smaller than
            ``args.max_inference_batch_size`` or a multiple of it.
        query_template: Template forwarded to ``_parse_and_to_tensor``.
        output_path: Directory that receives the images (created if missing).

    Returns:
        None.

    Raises:
        NotImplementedError: If ``args.generation_task`` is not one of
            ``'text2image'``, ``'low-level super-resolution'`` or
            ``'image2text'``.
        AssertionError: If ``num`` is incompatible with the maximum batch size.
    """
    task = args.generation_task
    # Fail fast, before any model work or file-system side effects. The
    # original code had a bare ``NotImplementedError`` expression here, which
    # is a no-op and let unsupported tasks fall through silently.
    if task not in ('text2image', 'low-level super-resolution', 'image2text'):
        raise NotImplementedError(f'unsupported generation task: {task!r}')

    tokenizer = get_tokenizer()
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_path, exist_ok=True)
    if seq is None:  # need parse
        img_size = 128 if task == 'low-level super-resolution' else 256
        seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template)

    model.eval()
    with torch.no_grad():
        print('show raw text:', raw_text)
        start_time = time.time()

        mbz = args.max_inference_batch_size
        # Checked before ``seq`` is touched so a bad ``num`` leaves it unmodified.
        assert num < mbz or num % mbz == 0, \
            'num must be smaller than max_inference_batch_size or a multiple of it'
        # Return value is not used; presumably marks ``seq`` in place -- TODO confirm.
        add_interlacing_beam_marks(seq, nb=min(num, mbz))

        batches = []
        for _ in range(max(num // mbz, 1)):
            # Each batch gets its own copy of the prepared sequence.
            batches.append(filling_sequence(model, seq.clone(), args))
            torch.cuda.empty_cache()
        output_tokens_list = torch.cat(batches, dim=0)

        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)

        imgs, txts = [], []
        for tokens in output_tokens_list:
            decoded_txts, decoded_imgs = tokenizer.DecodeIds(tokens.tolist())
            # Images whose last dimension is 128 are resized to 256x256.
            for idx, img in enumerate(decoded_imgs):
                if img.shape[-1] == 128:
                    decoded_imgs[idx] = torch.nn.functional.interpolate(img, size=(256, 256))
            if args.debug:
                imgs.extend(decoded_imgs)
            else:
                imgs.append(decoded_imgs[-1])  # only the last image (target)
            txts.append(decoded_txts)

        if task == 'image2text':
            print(txts)
            return

        if args.debug:
            output_file_prefix = raw_text.replace('/', '')[:20]
            timestamp = datetime.now().strftime('%m-%d-%H-%M-%S')
            output_file = os.path.join(output_path, f"{output_file_prefix}-{timestamp}.jpg")
            imgs = torch.cat(imgs, dim=0)
            print(txts)
            print("\nSave to: ", output_file, flush=True)
            save_image(imgs, output_file, normalize=True)
        else:
            print("\nSave to: ", output_path, flush=True)
            # Same permission bits as before: rwx for user, group and others.
            all_rwx = stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO
            for idx, img in enumerate(imgs):
                img_file = os.path.join(output_path, f'{idx}.jpg')
                save_image(img, img_file, normalize=True)
                os.chmod(img_file, all_rwx)
            concat_file = os.path.join(output_path, 'concat.jpg')
            save_image(torch.cat(imgs, dim=0), concat_file, normalize=True)
            os.chmod(concat_file, all_rwx)
no test coverage detected