MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

tools/scoring/optical_flow/inference.py:141–242  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

139
140@torch.no_grad()
141def main():
142 args = parse_args()
143
144 meta_path = args.meta_path
145 if not os.path.exists(meta_path):
146 print(f"Meta file '{meta_path}' not found. Exit.")
147 exit()
148
149 wo_ext, ext = os.path.splitext(meta_path)
150 out_path = f"{wo_ext}_flow{ext}"
151 if args.skip_if_existing and os.path.exists(out_path):
152 print(f"Output meta file '{out_path}' already exists. Exit.")
153 exit()
154
155 torch.backends.cudnn.deterministic = True
156 torch.backends.cudnn.benchmark = False
157 dist.init_process_group(backend="nccl", timeout=timedelta(hours=24))
158 torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
159
160 # build model
161 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
162 model = UniMatch(
163 feature_channels=128,
164 num_scales=2,
165 upsample_factor=4,
166 num_head=1,
167 ffn_dim_expansion=4,
168 num_transformer_layers=6,
169 reg_refine=True,
170 task="flow",
171 )
172 ckpt = torch.load("./pretrained_models/unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth")
173 model.load_state_dict(ckpt["model"])
174 model = model.to(device)
175
176 # build dataset
177 NUM_FRAMES = 10
178 frames_inds = [15 * i for i in range(0, NUM_FRAMES)]
179 dataset = VideoTextDataset(meta_path=meta_path, frame_inds=frames_inds)
180 dataloader = DataLoader(
181 dataset,
182 batch_size=args.bs,
183 num_workers=args.num_workers,
184 sampler=DistributedSampler(
185 dataset,
186 num_replicas=dist.get_world_size(),
187 rank=dist.get_rank(),
188 shuffle=False,
189 drop_last=False,
190 ),
191 )
192
193 # compute optical flow scores
194 indices_list = []
195 scores_list = []
196 model.eval()
197 for batch in tqdm(dataloader, disable=dist.get_rank() != 0):
198 indices = batch["index"]

Callers 1

inference.py (File) · 0.70

Calls 8

UniMatch (Class) · 0.90
tqdm (Function) · 0.85
to (Method) · 0.80
parse_args (Function) · 0.70
VideoTextDataset (Class) · 0.70
merge_scores (Function) · 0.70
device (Method) · 0.45
load_state_dict (Method) · 0.45

Tested by

no test coverage detected