(model, args)
| 200 | os.chmod(os.path.join(output_path,f'concat.jpg'), stat.S_IRWXO+stat.S_IRWXG+stat.S_IRWXU) |
| 201 | |
def generate_images_continually(model, args):
    """Run the configured generation task for every query from ``get_context``.

    A query template is selected by ``args.generation_task``. Each
    ``(raw_text, seq, output_path)`` tuple yielded by
    ``get_context(args, query_template)`` is then dispatched:

    * ``'super-resolution'``  -> ``super_resolution``
    * ``'post-selection'``    -> ``post_selection``
    * anything else supported -> ``generate_images_once`` with
      ``num=args.batch_size``

    Args:
        model: Passed through unchanged to the task routine.
        args: Namespace-like config. ``generation_task`` selects the
            template and routine; ``batch_size`` is forwarded as ``num`` to
            ``generate_images_once``. The whole object is also handed to
            ``get_context`` and the task routines.

    Raises:
        NotImplementedError: If ``args.generation_task`` is not one of the
            supported tasks. Raised before ``get_context`` is called.
    """
    # One template per supported task. The "{}" slots and the "[MASK]*N"
    # shorthand are consumed by get_context. Presumably it formats in the
    # input text/image and expands the mask run; confirm against get_context
    # before editing these strings.
    query_templates = {
        'text2image': '[ROI1] {} [BASE] [BOI1] [MASK]*1024',
        'image2text': '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] [MASK]*20',
        'low-level super-resolution': '[ROI1] {} [BASE] [BOI1] [Image]{} [EOI1] [ROI2] [POS0] [BASE] [BOI2] [MASK]*1024',
        'super-resolution': '[ROI1] {} [BASE] [BOI1] [Image]{}',
        'post-selection': '[BASE] [BOI1] [Image]{} [EOI1] [ROI1] {}',
    }
    query_template = query_templates.get(args.generation_task)
    if query_template is None:
        # Name the offending value and the valid options instead of a bare
        # NotImplementedError, so misconfiguration is diagnosable.
        raise NotImplementedError(
            f"Unsupported generation_task {args.generation_task!r}; "
            f"expected one of {sorted(query_templates)}"
        )
    for raw_text, seq, output_path in get_context(args, query_template):
        if args.generation_task == 'super-resolution':
            super_resolution(model, args, raw_text, seq, output_path=output_path)
        elif args.generation_task == 'post-selection':
            post_selection(model, args, raw_text, seq, output_path=output_path)
        else:
            generate_images_once(model, args, raw_text, seq, num=args.batch_size, output_path=output_path)
| 223 | def super_resolution(model, args, raw_text, seq, output_path="./samples"): |
| 224 | tokenizer = get_tokenizer() |
no test coverage detected