(argv)
| 60 | |
| 61 | |
| 62 | def main(argv): |
| 63 | assert FLAGS.input_image_dir != '' |
| 64 | assert FLAGS.output_image_dir != '' |
| 65 | assert FLAGS.output_file != '' |
| 66 | |
| 67 | # Load the pre-trained vq model from the hub |
| 68 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 69 | |
| 70 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) |
| 71 | net.eval() |
| 72 | |
| 73 | |
| 74 | input_images = os.listdir(FLAGS.input_image_dir) |
| 75 | input_images = [i for i in input_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')] |
| 76 | input_images = [i for i in input_images if FLAGS.input_filter_key in i] |
| 77 | input_images = sorted(input_images) |
| 78 | output_images = os.listdir(FLAGS.output_image_dir) |
| 79 | output_images = [i for i in output_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')] |
| 80 | output_images = [i for i in output_images if FLAGS.output_filter_key in i] |
| 81 | output_images = sorted(output_images) |
| 82 | |
| 83 | assert len(input_images) == len(output_images) |
| 84 | |
| 85 | |
| 86 | input_images = [ |
| 87 | os.path.join(FLAGS.input_image_dir, s) |
| 88 | for s in input_images |
| 89 | ] |
| 90 | output_images = [ |
| 91 | os.path.join(FLAGS.output_image_dir, s) |
| 92 | for s in output_images |
| 93 | ] |
| 94 | |
| 95 | dataset = PairedImageDataset(input_images, output_images) |
| 96 | dataloader = torch.utils.data.DataLoader( |
| 97 | dataset, |
| 98 | batch_size=FLAGS.batch_size * FLAGS.n_shots, |
| 99 | shuffle=False, |
| 100 | num_workers=FLAGS.n_workers, |
| 101 | drop_last=True |
| 102 | ) |
| 103 | |
| 104 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) |
| 105 | |
| 106 | with torch.no_grad(): |
| 107 | with NamedTemporaryFile() as ntf: |
| 108 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512)) |
| 109 | all_tokens[:] = 0 |
| 110 | |
| 111 | index = 0 |
| 112 | for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0): |
| 113 | _, input_token_batch = net.encode(input_image_batch.permute(0,3,1,2).to(device)) |
| 114 | _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device)) |
| 115 | |
| 116 | |
| 117 | all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate( |
| 118 | [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)], |
| 119 | axis=1 |
nothing calls this directly
no test coverage detected