(argv)
| 49 | |
| 50 | |
def main(argv):
    """Tokenize every video under ``FLAGS.input_dir`` into a JSONL file.

    Loads the VQGAN checkpoint, walks ``FLAGS.input_dir`` for files accepted
    by ``is_video``, and for each epoch and each stride in ``FLAGS.strides``
    encodes batches of frames with the model. Each sample becomes one line of
    ``FLAGS.output_file``: a JSON object whose ``tokens`` field is the
    base64-encoded bytes of the sample's int32 token ids, with the per-frame
    token vectors concatenated in frame order.

    Args:
        argv: Positional command-line arguments; not read by this function.

    Raises:
        ValueError: If ``FLAGS.input_dir`` or ``FLAGS.output_file`` is empty,
            or if an entry of ``FLAGS.strides`` is not an integer.
    """
    del argv  # Unused.

    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let empty paths through.
    if not FLAGS.input_dir:
        raise ValueError('input_dir must be a non-empty path')
    if not FLAGS.output_file:
        raise ValueError('output_file must be a non-empty path')

    # Parse the strides once, before the output file is opened, so a
    # malformed flag fails without truncating an existing output file.
    strides = [int(s) for s in FLAGS.strides.split(',')]

    # Load the pre-trained vq model from the hub
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
    net.eval()

    videos = [
        os.path.join(root, name)
        for root, _, files in os.walk(FLAGS.input_dir)
        for name in files
        if is_video(name)
    ]

    # DataLoader rejects `prefetch_factor` when num_workers == 0 (prefetching
    # only exists in multi-process loading), so pass it only with workers.
    loader_kwargs = {}
    if FLAGS.n_workers > 0:
        loader_kwargs['prefetch_factor'] = 4

    with open(FLAGS.output_file, 'w') as fout, torch.no_grad():
        for _ in trange(FLAGS.n_epochs, ncols=0):
            for stride in tqdm(strides, ncols=0):
                # The dataset is rebuilt for every (epoch, stride) pair, as in
                # the original control flow.
                dataset = VideoDataset(
                    videos, n_frames=FLAGS.n_frames, stride=stride
                )
                dataloader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=FLAGS.batch_size,
                    shuffle=False,
                    num_workers=FLAGS.n_workers,
                    drop_last=True,
                    **loader_kwargs,
                )
                for batch in tqdm(dataloader, ncols=0):
                    batch_size = batch.shape[0]
                    # Fold the time axis into the batch axis and move the
                    # channel axis forward so each frame is encoded on its own.
                    frames = einops.rearrange(
                        batch.numpy(), 'b t h w c -> (b t) c h w'
                    )
                    frames = torch.tensor(frames).to(device)
                    _, tokens = net.encode(frames)
                    # Regroup per-frame token rows into one flat row per sample.
                    tokens = einops.rearrange(
                        tokens.cpu().numpy().astype(np.int32),
                        '(b t) d -> b (t d)',
                        b=batch_size,
                    )
                    for sample_tokens in tokens:
                        data = {
                            'tokens': b64encode(
                                sample_tokens.tobytes()
                            ).decode('utf-8'),
                        }
                        fout.write(json.dumps(data) + '\n')
| 94 | |
| 95 | |
| 96 |
nothing calls this directly
no test coverage detected