(
self,
tasks,
stream: bool = True,
use_tqdm: bool = True,
metrics: Optional[List[Metric]] = None
)
| 112 | return await asyncio.gather(*tasks) |
| 113 | |
def _batch_infer_stream(
    self,
    tasks,
    stream: bool = True,
    use_tqdm: bool = True,
    metrics: Optional[List[Metric]] = None
) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
    """Run a batch of inference awaitables, either as lazy streams or eagerly.

    Args:
        tasks: Sized collection of awaitables, one per request. In stream mode
            each one is handed to ``self.async_iter_to_iter``; otherwise each
            one is awaited and its result collected.
        stream: If True, return one iterator per task (built by
            ``self.async_iter_to_iter``) without running anything here. If
            False, drive all tasks to completion on an event loop.
        use_tqdm: Whether to display a progress bar sized to ``len(tasks)``.
        metrics: Optional metric objects. They are forwarded to
            ``self.async_iter_to_iter`` (stream) or ``self._update_metrics``
            (non-stream).

    Returns:
        In stream mode, a list of iterators. Otherwise, whatever
        ``self.batch_run`` returns for the wrapped tasks. When ``self.strict``
        is falsy (it defaults to True if unset), a failed task contributes its
        exception object in place of a response.

    Raises:
        Exception: In non-stream mode with ``strict`` truthy, an exception
            raised by a task propagates out through ``self.batch_run``.
    """
    prog_bar = tqdm(total=len(tasks), dynamic_ncols=True, disable=not use_tqdm)
    if stream:
        # The bar is handed off; the per-task iterators are responsible for it.
        return [self.async_iter_to_iter(task, prog_bar, metrics) for task in tasks]

    async def _new_run(task):
        try:
            res = await task
        except Exception as e:
            if getattr(self, 'strict', True):
                raise
            # Non-strict mode: record the failure in-place and keep going.
            res = e
        prog_bar.update()
        self._update_metrics(res, metrics)
        return res

    new_tasks = [_new_run(task) for task in tasks]
    try:
        # Prefer the thread's current loop so the tasks run on the same loop as
        # any loop-bound state (presumably the engine's) - TODO confirm.
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        # A closed loop cannot run anything; install a fresh one instead of
        # failing with "Event loop is closed".
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(self.batch_run(new_tasks))
    finally:
        # Always finalize the bar, including when a strict-mode task raises.
        prog_bar.close()
| 145 | |
| 146 | @staticmethod |
| 147 | def _get_usage_info(num_prompt_tokens: int, num_generated_tokens: int) -> UsageInfo: |
no test coverage detected