MCPcopy
hub / github.com/ytongbai/LVM / main

Function main

tokenize_examples/tokenize_paired_dataset_muse.py:62–129  ·  view source on GitHub ↗
(argv)

Source from the content-addressed store, hash-verified

60
61
62def main(argv):
63 assert FLAGS.input_image_dir != ''
64 assert FLAGS.output_image_dir != ''
65 assert FLAGS.output_file != ''
66
67 # Load the pre-trained vq model from the hub
68 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
69
70 net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
71 net.eval()
72
73
74 input_images = os.listdir(FLAGS.input_image_dir)
75 input_images = [i for i in input_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')]
76 input_images = [i for i in input_images if FLAGS.input_filter_key in i]
77 input_images = sorted(input_images)
78 output_images = os.listdir(FLAGS.output_image_dir)
79 output_images = [i for i in output_images if i.endswith('.png') or i.endswith('.jpg') or i.endswith('.jpeg')]
80 output_images = [i for i in output_images if FLAGS.output_filter_key in i]
81 output_images = sorted(output_images)
82
83 assert len(input_images) == len(output_images)
84
85
86 input_images = [
87 os.path.join(FLAGS.input_image_dir, s)
88 for s in input_images
89 ]
90 output_images = [
91 os.path.join(FLAGS.output_image_dir, s)
92 for s in output_images
93 ]
94
95 dataset = PairedImageDataset(input_images, output_images)
96 dataloader = torch.utils.data.DataLoader(
97 dataset,
98 batch_size=FLAGS.batch_size * FLAGS.n_shots,
99 shuffle=False,
100 num_workers=FLAGS.n_workers,
101 drop_last=True
102 )
103
104 total_images = len(input_images) - len(input_images) % (FLAGS.batch_size * FLAGS.n_shots)
105
106 with torch.no_grad():
107 with NamedTemporaryFile() as ntf:
108 all_tokens = np.memmap(ntf, dtype='i4', mode='w+', shape=(total_images, 512))
109 all_tokens[:] = 0
110
111 index = 0
112 for input_image_batch, output_image_batch in tqdm(dataloader, ncols=0):
113 _, input_token_batch = net.encode(input_image_batch.permute(0,3,1,2).to(device))
114 _, output_token_batch = net.encode(output_image_batch.permute(0, 3, 1, 2).to(device))
115
116
117 all_tokens[index:index + input_image_batch.shape[0]] = np.concatenate(
118 [input_token_batch.cpu().numpy().astype(np.int32), output_token_batch.cpu().numpy().astype(np.int32)],
119 axis=1

Callers

nothing calls this directly

Calls 5

device — Method · 0.80
from_pretrained — Method · 0.80
encode — Method · 0.80
PairedImageDataset — Class · 0.70
decode — Method · 0.45

Tested by

no test coverage detected