hub / github.com/ytongbai/LVM / main

Function main

tokenize_examples/tokenize_video_muse.py:51–93
(argv)

Source from the content-addressed store, hash-verified

def main(argv):
    assert FLAGS.input_dir != ''
    assert FLAGS.output_file != ''

    # Load the pre-trained vq model from the hub
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = VQGANModel.from_pretrained('vqlm/muse/ckpts/laion').to(device)
    net.eval()

    videos = []
    for root, _, files in os.walk(FLAGS.input_dir):
        for file in files:
            if is_video(file):
                videos.append(os.path.join(root, file))

    with open(FLAGS.output_file, 'w') as fout:
        with torch.no_grad():
            for epoch in trange(FLAGS.n_epochs, ncols=0):
                for stride in tqdm(FLAGS.strides.split(','), ncols=0):
                    stride = int(stride)
                    dataset = VideoDataset(videos, n_frames=FLAGS.n_frames, stride=stride)
                    dataloader = torch.utils.data.DataLoader(
                        dataset,
                        batch_size=FLAGS.batch_size,
                        shuffle=False,
                        num_workers=FLAGS.n_workers,
                        prefetch_factor=4,
                        drop_last=True,
                    )
                    for batch in tqdm(dataloader, ncols=0):
                        batch_size = batch.shape[0]
                        batch = einops.rearrange(
                            batch.numpy(), 'b t h w c -> (b t) c h w'
                        )
                        batch = torch.tensor(batch).to(device)
                        _, tokens = net.encode(batch)
                        tokens = einops.rearrange(
                            tokens.cpu().numpy().astype(np.int32), '(b t) d -> b (t d)', b=batch_size
                        )
                        for i in range(batch_size):
                            data = {'tokens': b64encode(tokens[i].tobytes()).decode('utf-8')}
                            fout.write(json.dumps(data) + '\n')
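The output format can be read back by inverting the last few steps of main. The following is a minimal sketch, not part of the repository: it uses NumPy only, replaces the VQGAN encoder output with a dummy array, and assumes illustrative sizes (batch_size, n_frames, tokens_per_frame) and hypothetical helper names (encode_record, decode_record). The real number of tokens per frame depends on the checkpoint and is not stated in the source.

```python
import base64
import json

import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch_size, n_frames, tokens_per_frame = 2, 4, 3

# Stand-in for the encoder output: one row of token ids per frame, with the
# frames of all clips flattened into the leading axis, i.e. '(b t) d'.
tokens = np.arange(batch_size * n_frames * tokens_per_frame, dtype=np.int32)
tokens = tokens.reshape(batch_size * n_frames, tokens_per_frame)

# Same regrouping as the einops pattern '(b t) d -> b (t d)': the frames of
# each clip are concatenated in temporal order into one flat sequence.
per_clip = tokens.reshape(batch_size, n_frames * tokens_per_frame)


def encode_record(row):
    """Serialize one clip's int32 tokens the way main writes a JSONL line."""
    payload = base64.b64encode(row.tobytes()).decode('utf-8')
    return json.dumps({'tokens': payload})


def decode_record(line):
    """Recover the int32 token array from one JSONL line."""
    raw = base64.b64decode(json.loads(line)['tokens'])
    # The bytes carry no dtype or shape, so both must be known by the reader.
    # Byte order is the writer's native order (assumed to match the reader's).
    return np.frombuffer(raw, dtype=np.int32)


lines = [encode_record(row) for row in per_clip]
decoded = [decode_record(line) for line in lines]

# Per-frame tokens can be recovered if n_frames is known.
frames_of_first_clip = decoded[0].reshape(n_frames, tokens_per_frame)
```

Because each record stores raw bytes only, a consumer has to know the dtype (int32) and the n_frames value used at tokenization time to split a sequence back into frames.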

Callers

nothing calls this directly

Calls 6

is_video · Function · 0.90
device · Method · 0.80
from_pretrained · Method · 0.80
encode · Method · 0.80
VideoDataset · Class · 0.70
decode · Method · 0.45

Tested by

no test coverage detected