(argv)
| 70 | |
| 71 | |
| 72 | def main(argv): |
| 73 | assert FLAGS.input_dir != '' |
| 74 | assert FLAGS.output_file != '' |
| 75 | |
| 76 | # Load the pre-trained vq model from the hub |
| 77 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| 78 | |
| 79 | net = VQGANModel.from_pretrained('/home/vqlm/muse/ckpts/laion').to(device) |
| 80 | net.eval() |
| 81 | |
| 82 | dirs = [] |
| 83 | for d in list_dir_with_full_path(FLAGS.input_dir): |
| 84 | if not os.path.isdir(d): |
| 85 | continue |
| 86 | for d2 in list_dir_with_full_path(d): |
| 87 | if not os.path.isdir(d2): |
| 88 | continue |
| 89 | dirs.append(d2) |
| 90 | |
| 91 | image_dirs = [] |
| 92 | for d in dirs: |
| 93 | image_dirs.append(( |
| 94 | os.path.join(d, 'images'), |
| 95 | os.path.join(d, 'masks'), |
| 96 | os.path.join(d, 'depth_masks') |
| 97 | )) |
| 98 | |
| 99 | with open(FLAGS.output_file, 'w') as fout: |
| 100 | with torch.no_grad(): |
| 101 | for _ in trange(FLAGS.n_epochs, ncols=0): |
| 102 | print(image_dirs[0]) |
| 103 | dataset = Co3DDataset(image_dirs, FLAGS.n_frames) |
| 104 | dataloader = torch.utils.data.DataLoader( |
| 105 | dataset, |
| 106 | batch_size=FLAGS.batch_size * FLAGS.n_shots, |
| 107 | shuffle=False, |
| 108 | num_workers=FLAGS.n_workers, |
| 109 | drop_last=True |
| 110 | ) |
| 111 | |
| 112 | for batch in tqdm(dataloader, ncols=0): |
| 113 | batch_shape = batch.shape[:-3] |
| 114 | batch = batch.reshape(-1, *batch.shape[-3:]) |
| 115 | batch = batch.permute(0,3,1,2) |
| 116 | batch = batch.to(device) |
| 117 | |
| 118 | _, tokens = net.encode(batch) |
| 119 | tokens = tokens.reshape(*batch_shape, tokens.shape[-1]) |
| 120 | # batch x task x frame x token |
| 121 | tokens = einops.rearrange( |
| 122 | tokens.cpu().numpy().astype(np.int32), '(b s) t f d -> b s t f d', |
| 123 | s=FLAGS.n_shots |
| 124 | ) |
| 125 | |
| 126 | image_mask_tokens = np.concatenate( |
| 127 | (tokens[:, :, 0, :, :], tokens[:, :, 1, :, :]), axis=-2 |
| 128 | ) |
| 129 | image_depth_tokens = np.concatenate( |
nothing calls this directly
no test coverage detected