(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output)
| 143 | self.csv = csv |
| 144 | |
def process_request(self, requests, profiler: Profiler, temperature, top_p, top_k, stream_output):
    """Run all requests through ``self.pipe`` and record progress per session.

    Args:
        requests: Sized, re-iterable collection of 3-tuples, unpacked here
            as ``(prompt, input_len, output_len)``.
        profiler: Supplies one session per request via ``new_session`` and
            is started before / finished after the inference loop.
        temperature: Forwarded to every ``GenerationConfig``.
        top_p: Forwarded to every ``GenerationConfig``.
        top_k: Forwarded to every ``GenerationConfig``.
        stream_output: When truthy, iterate ``self.pipe.stream_infer`` and
            tick the matching session on every streamed output; otherwise
            call ``self.pipe`` directly and tick/finish once per output.
    """
    prompts = [prompt for prompt, _, _ in requests]
    # One config per request so that each request carries its own
    # ``max_new_tokens``.
    gen_configs = [
        GenerationConfig(temperature=temperature,
                         top_p=top_p,
                         top_k=top_k,
                         ignore_eos=True,
                         do_sample=False,
                         return_routed_experts=self.return_routed_experts,
                         max_new_tokens=output_len) for _, _, output_len in requests
    ]

    # Sessions are index-aligned with ``prompts`` / ``gen_configs``; outputs
    # are routed back to their session through ``output.index``.
    sess: list[Session] = [
        profiler.new_session(input_len, output_len) for _, input_len, output_len in requests
    ]

    def _to_status(finish_reason):
        # Any finish reason other than 'length' is recorded as a failure.
        return Session.SUCCESS if finish_reason == 'length' else Session.FAIL

    profiler.start()

    # Initial tick with zero tokens for every session before inference begins.
    for s in sess:
        s.tick(0)

    if stream_output:
        # Context manager guarantees the progress bar is closed even when
        # the streaming loop raises.
        with tqdm(total=len(requests)) as pbar:
            for output in self.pipe.stream_infer(prompts, gen_config=gen_configs, do_preprocess=False):
                index = output.index
                finish_reason = output.finish_reason
                sess[index].tick(output.generate_token_len)
                # A non-None finish reason marks the last output of a request.
                if finish_reason is not None:
                    sess[index].finish(_to_status(finish_reason))
                    pbar.update(1)
    else:
        for output in self.pipe(prompts, gen_configs, do_preprocess=False, use_tqdm=True):
            index = output.index
            sess[index].tick(output.generate_token_len)
            sess[index].finish(_to_status(output.finish_reason))

    profiler.finish()

    # report first failure
    for i, s in enumerate(sess):
        if s.status != Session.SUCCESS or s.ns[-1] < s.req_output_len:
            logger.error(f'Request {i} failed with {s.ns[-1]}/{s.req_output_len} tokens generated')
            logger.error(f'Prompt: {prompts[i]}')
            logger.warning('Got failed requests, metrics may be invalid')
            break
| 202 |
no test coverage detected