(argv)
| 109 | |
| 110 | |
| 111 | def main(argv): |
| 112 | assert FLAGS.input_dir != '' |
| 113 | assert FLAGS.input_regex != '' |
| 114 | assert FLAGS.output_file != '' |
| 115 | |
| 116 | # Load the pre-trained vq model from the hub |
| 117 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 118 | |
| 119 | net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device) |
| 120 | net.eval() |
| 121 | |
| 122 | regex = FLAGS.input_regex.split('::') |
| 123 | input_images = match_mulitple_path_v2(FLAGS.input_dir, regex) |
| 124 | |
| 125 | print(f'Found {len(input_images)} images') |
| 126 | assert len(input_images) > 0, 'No images found' |
| 127 | |
| 128 | if FLAGS.max_examples > 0: |
| 129 | input_images = input_images[:FLAGS.max_examples] |
| 130 | |
| 131 | random.shuffle(input_images) |
| 132 | |
| 133 | dataset = MultipleImageDataset(input_images) |
| 134 | dataloader = torch.utils.data.DataLoader( |
| 135 | dataset, |
| 136 | batch_size=FLAGS.batch_size * FLAGS.n_shots, |
| 137 | shuffle=False, |
| 138 | num_workers=FLAGS.n_workers, |
| 139 | drop_last=True |
| 140 | ) |
| 141 | |
| 142 | total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots) |
| 143 | |
| 144 | with NamedTemporaryFile() as ntf: |
| 145 | all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 256 * len(input_images[0]))) |
| 146 | all_tokens[:] = 0 |
| 147 | |
| 148 | index = 0 |
| 149 | for batch in tqdm(dataloader, ncols=0): |
| 150 | k = 0 |
| 151 | for image in batch: |
| 152 | batch_size = image.shape[0] |
| 153 | image = einops.rearrange( |
| 154 | image.numpy(), 'b h w c -> b c h w' |
| 155 | ) |
| 156 | image = torch.tensor(image).to(device) |
| 157 | _, tokens = net.encode(image) |
| 158 | tokens = einops.rearrange( |
| 159 | tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size |
| 160 | ) |
| 161 | all_tokens[index:index + image.shape[0], k:k + 256] = tokens |
| 162 | k += 256 |
| 163 | index += batch[0].shape[0] |
| 164 | |
| 165 | with open(FLAGS.output_file, 'w') as fout: |
| 166 | for _ in trange(FLAGS.n_epochs, ncols=0): |
| 167 | indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots) |
| 168 | for i in trange(indices.shape[0], ncols=0): |
nothing calls this directly
no test coverage detected