| 248 | |
| 249 | |
def inference_loop(args, model, dataset, q: Queue):
    """Run inference over this rank's shard of ``dataset`` and stream results to ``q``.

    The dataset is split across ranks with a non-shuffling ``DistributedSampler``
    and batched with ``CSVDataset.collate_fn``. For every batch one list of
    predictions is put on ``q``. If a batch fails (the collate function yielded
    ``None`` for it, or ``infer`` raised), the error is logged and a list of
    ``args.error_message`` placeholders is queued instead, so the loop keeps
    going. After the last batch a single ``None`` is queued as the
    end-of-stream marker.

    Args:
        args: Run configuration; ``batch_size`` and ``error_message`` are read.
        model: Model object forwarded unchanged to ``infer``.
        dataset: Dataset to run inference on.
        q: Queue receiving one prediction list per batch, then ``None``.
    """
    rank = dist.get_rank()
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=False,
    )
    dataloader = DataLoader(
        dataset,
        num_workers=2,
        batch_size=args.batch_size,
        collate_fn=CSVDataset.collate_fn,
        pin_memory=True,
        sampler=sampler,
    )

    # Only rank 0 shows a progress bar.
    for i, (batch, max_tokens) in enumerate(tqdm(dataloader, disable=rank != 0)):
        try:
            if batch is None:
                raise Exception("Video not loaded properly")
            preds = infer(
                model,
                batch,
                max_tokens=max_tokens,
            )
        except Exception as e:
            logger.error(f"error at rank {rank} sample {i}: {str(e)}")
            traceback.print_exception(e)
            if batch is None:
                # len(None) would raise TypeError inside this handler, abort the
                # loop and leave the consumer without its end-of-stream marker.
                # Derive how many samples batch ``i`` held instead: the loader
                # keeps the final partial batch (drop_last is left at its
                # default), so only the last batch can be short.
                # NOTE(review): assumes one prediction per sample is expected
                # downstream -- confirm against the consumer of ``q``.
                remaining = len(sampler) - i * args.batch_size
                num_failed = max(1, min(args.batch_size, remaining))
            else:
                num_failed = len(batch)
            # One error placeholder per entry of the failed batch.
            preds = [args.error_message] * num_failed
        q.put(preds)
    # Signal the consumer that no more batches will follow.
    q.put(None)
| 277 | |
| 278 | |
| 279 | def post_process_loop(processor, q: Queue, result_q: Queue): |