hub / github.com/hpcaitech/Open-Sora / inference_loop

Function inference_loop

tools/caption/pllava_dir/caption_pllava.py:250–276
(args, model, dataset, q: Queue)

Source from the content-addressed store, hash-verified

def inference_loop(args, model, dataset, q: Queue):
    dataloader = DataLoader(
        dataset,
        num_workers=2,
        batch_size=args.batch_size,
        collate_fn=CSVDataset.collate_fn,
        pin_memory=True,
        sampler=DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=False),
    )

    for i, (batch, max_tokens) in enumerate(tqdm(dataloader, disable=dist.get_rank() != 0)):
        try:
            if batch is None:
                raise Exception("Video not loaded properly")
            preds = infer(
                model,
                batch,
                max_tokens=max_tokens,
            )
        except Exception as e:
            logger.error(f"error at rank {dist.get_rank()} sample {i}: {str(e)}")
            traceback.print_exception(e)
            # preds = args.error_message duplicated for each video in the batch
            preds = [args.error_message] * len(batch)
        q.put(preds)
    # finish the queue
    q.put(None)


def post_process_loop(processor, q: Queue, result_q: Queue):
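
The listing above follows a producer/consumer pattern: each batch's predictions are pushed onto a queue, a failed batch is replaced by one error placeholder per item, and a final None tells the consumer to stop. The following is a minimal sketch of that pattern only, not the repository's code. It assumes a standard-library queue.Queue and a thread; fake_infer, produce, consume, run, and ERROR_MESSAGE are hypothetical names introduced for illustration.

```python
import queue
import threading

ERROR_MESSAGE = "ERROR"  # assumed stand-in for args.error_message


def fake_infer(batch):
    # Hypothetical stand-in for infer(); fails if any item is missing.
    if any(item is None for item in batch):
        raise ValueError("Video not loaded properly")
    return [f"caption for {item}" for item in batch]


def produce(batches, q):
    for batch in batches:
        try:
            preds = fake_infer(batch)
        except Exception:
            # One placeholder per item keeps outputs aligned with inputs.
            preds = [ERROR_MESSAGE] * len(batch)
        q.put(preds)
    q.put(None)  # sentinel: no more batches will arrive


def consume(q):
    results = []
    while True:
        preds = q.get()
        if preds is None:  # producer signalled completion
            break
        results.extend(preds)
    return results


def run(batches):
    q = queue.Queue()
    worker = threading.Thread(target=produce, args=(batches, q))
    worker.start()
    results = consume(q)
    worker.join()
    return results
```

In this sketch a failing batch is still a list, so len(batch) is well defined in the handler. In the listing as shown, a batch that is None would reach len(batch) inside the except block, which raises TypeError in Python, so a guard for that case may be worth considering.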

Callers 1

main · Function · 0.85

Calls 2

tqdm · Function · 0.85
infer · Function · 0.85

Tested by

no test coverage detected