MCPcopy
hub / github.com/ytongbai/LVM / main

Function main

tokenize_examples/tokenize_co3d_muse.py:72–136  ·  view source on GitHub ↗
(argv)

Source from the content-addressed store, hash-verified

70
71
72def main(argv):
73 assert FLAGS.input_dir != ''
74 assert FLAGS.output_file != ''
75
76 # Load the pre-trained vq model from the hub
77 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
79 net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device)
80 net.eval()
81
82 dirs = []
83 for d in list_dir_with_full_path(FLAGS.input_dir):
84 if not os.path.isdir(d):
85 continue
86 for d2 in list_dir_with_full_path(d):
87 if not os.path.isdir(d2):
88 continue
89 dirs.append(d2)
90
91 image_dirs = []
92 for d in dirs:
93 image_dirs.append((
94 os.path.join(d, 'images'),
95 os.path.join(d, 'masks'),
96 os.path.join(d, 'depth_masks')
97 ))
98
99 with open(FLAGS.output_file, 'w') as fout:
100 with torch.no_grad():
101 for _ in trange(FLAGS.n_epochs, ncols=0):
102 print(image_dirs[0])
103 dataset = Co3DDataset(image_dirs, FLAGS.n_frames)
104 dataloader = torch.utils.data.DataLoader(
105 dataset,
106 batch_size=FLAGS.batch_size * FLAGS.n_shots,
107 shuffle=False,
108 num_workers=FLAGS.n_workers,
109 drop_last=True
110 )
111
112 for batch in tqdm(dataloader, ncols=0):
113 batch_shape = batch.shape[:-3]
114 batch = batch.reshape(-1, *batch.shape[-3:])
115 batch = batch.permute(0,3,1,2)
116 batch = batch.to(device)
117
118 _, tokens = net.encode(batch)
119 tokens = tokens.reshape(*batch_shape, tokens.shape[-1])
120 # batch x task x frame x token
121 tokens = einops.rearrange(
122 tokens.cpu().numpy().astype(np.int32), '(b s) t f d -> b s t f d',
123 s=FLAGS.n_shots
124 )
125
126 image_mask_tokens = np.concatenate(
127 (tokens[:, :, 0, :, :], tokens[:, :, 1, :, :]), axis=-2
128 )
129 image_depth_tokens = np.concatenate(

Callers

nothing calls this directly

Calls 5

list_dir_with_full_path · Function · 0.90
Co3DDataset · Class · 0.85
device · Method · 0.80
from_pretrained · Method · 0.80
encode · Method · 0.80

Tested by

no test coverage detected