MCPcopy
hub / github.com/ytongbai/LVM / main

Function main

tokenize_examples/tokenize_multi_datasets_muse.py:111–176  ·  view source on GitHub ↗
(argv)

Source from the content-addressed store, hash-verified

109
110
111def main(argv):
112 assert FLAGS.input_dir != ''
113 assert FLAGS.input_regex != ''
114 assert FLAGS.output_file != ''
115
116 # Load the pre-trained vq model from the hub
117 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
118
119 net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
120 net.eval()
121
122 regex = FLAGS.input_regex.split('::')
123 input_images = match_mulitple_path_v2(FLAGS.input_dir, regex)
124
125 print(f'Found {len(input_images)} images')
126 assert len(input_images) > 0, 'No images found'
127
128 if FLAGS.max_examples > 0:
129 input_images = input_images[:FLAGS.max_examples]
130
131 random.shuffle(input_images)
132
133 dataset = MultipleImageDataset(input_images)
134 dataloader = torch.utils.data.DataLoader(
135 dataset,
136 batch_size=FLAGS.batch_size * FLAGS.n_shots,
137 shuffle=False,
138 num_workers=FLAGS.n_workers,
139 drop_last=True
140 )
141
142 total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
143
144 with NamedTemporaryFile() as ntf:
145 all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 256 * len(input_images[0])))
146 all_tokens[:] = 0
147
148 index = 0
149 for batch in tqdm(dataloader, ncols=0):
150 k = 0
151 for image in batch:
152 batch_size = image.shape[0]
153 image = einops.rearrange(
154 image.numpy(), 'b h w c -> b c h w'
155 )
156 image = torch.tensor(image).to(device)
157 _, tokens = net.encode(image)
158 tokens = einops.rearrange(
159 tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size
160 )
161 all_tokens[index:index + image.shape[0], k:k + 256] = tokens
162 k += 256
163 index += batch[0].shape[0]
164
165 with open(FLAGS.output_file, 'w') as fout:
166 for _ in trange(FLAGS.n_epochs, ncols=0):
167 indices = np.random.permutation(total_images).reshape(-1, FLAGS.n_shots)
168 for i in trange(indices.shape[0], ncols=0):

Callers

nothing calls this directly

Calls 6

match_mulitple_path_v2 · Function · 0.85
device · Method · 0.80
from_pretrained · Method · 0.80
encode · Method · 0.80
decode · Method · 0.45

Tested by

no test coverage detected