Collect useful stats about the run. Add more stats as needed.
| 42 | |
| 43 | |
class AsyncStatsCallbackHandler(AsyncCallbackHandler):
    """Collect useful stats about the run.

    Add more stats as needed.

    The handler accumulates, across calls:

    * ``cnt`` -- number of LLM calls observed,
    * ``input_tokens`` / ``output_tokens`` -- token totals,
    * ``all_times`` -- per-call duration in seconds (rounded to 2 decimals),
    * ``additional_fields`` -- free-form extra entries merged into
      :meth:`get_stats`.

    In streaming mode the token counts are computed locally (prompt text is
    encoded with ``tiktoken``; every streamed chunk counts as one output
    token).  In non-streaming mode the counts are read from the usage data
    attached to the response.

    NOTE(review): a single ``start_time`` is shared by all in-flight calls,
    so timings are only meaningful when calls do not overlap -- confirm
    against how callers use this handler.
    """

    def __init__(self, stream: bool = False) -> None:
        """Create a handler.

        Args:
            stream: ``True`` when the model is used in streaming mode, which
                switches token accounting to the local counting path.
        """
        super().__init__()
        self.cnt = 0
        self.input_tokens = 0
        self.output_tokens = 0
        # same for gpt-3.5
        self.encoder = tiktoken.encoding_for_model("gpt-4")
        self.stream = stream
        self.all_times = []
        self.additional_fields = {}
        self.start_time = 0

    def _count_tokens(self, content) -> int:
        """Return the number of tokens in one message's ``content``.

        ``content`` may be a plain string or a list of parts (strings or
        dicts carrying a ``"text"`` entry); anything else counts as zero.
        """
        if isinstance(content, str):
            # disallowed_special=() makes special-token look-alikes such as
            # "<|endoftext|>" encode as ordinary text instead of raising,
            # so user-supplied prompt text cannot crash stats collection.
            return len(self.encoder.encode(content, disallowed_special=()))
        if isinstance(content, (list, tuple)):
            total = 0
            for part in content:
                if isinstance(part, dict):
                    part = part.get("text")
                if isinstance(part, str):
                    total += len(
                        self.encoder.encode(part, disallowed_special=())
                    )
            return total
        return 0

    async def on_chat_model_start(self, serialized, prompts, **kwargs):
        """Record the call start time and, when streaming, count input tokens.

        Args:
            serialized: Unused here.
            prompts: Sequence of message sequences; each message is expected
                to expose a ``content`` attribute.
            **kwargs: Ignored.
        """
        self.start_time = time.time()
        if self.stream:
            # if streaming mode, on_llm_end response is not collected
            # therefore, we need to count input token based on the
            # prompt length at the beginning
            self.cnt += 1
            # Count every message of every prompt, not just the first one;
            # an empty prompt list simply contributes nothing.
            for messages in prompts:
                for message in messages:
                    self.input_tokens += self._count_tokens(
                        getattr(message, "content", None)
                    )

    async def on_llm_new_token(self, token, *args, **kwargs):
        """Count one output token per streamed chunk (streaming mode only)."""
        if self.stream:
            # if streaming mode, on_llm_end response is not collected
            # therefore, we need to manually count output token based on the
            # number of streamed out tokens
            self.output_tokens += 1

    async def on_llm_end(self, response, *args, **kwargs):
        """Record the call duration and, when not streaming, token usage.

        Args:
            response: Result object whose ``llm_output`` mapping may carry a
                ``"token_usage"`` entry with ``"prompt_tokens"`` and
                ``"completion_tokens"``.
        """
        self.all_times.append(round(time.time() - self.start_time, 2))
        if not self.stream:
            # if not streaming mode, on_llm_end response is collected
            # so we can use this stats directly
            # Usage data is optional: a missing llm_output / token_usage
            # must still count the call rather than raise.
            llm_output = response.llm_output or {}
            token_usage = llm_output.get("token_usage") or {}
            self.input_tokens += token_usage.get("prompt_tokens", 0)
            self.output_tokens += token_usage.get("completion_tokens", 0)
            self.cnt += 1

    def reset(self) -> None:
        """Clear all collected stats, keeping the encoder and stream mode."""
        self.cnt = 0
        self.input_tokens = 0
        self.output_tokens = 0
        self.all_times = []
        self.additional_fields = {}
        self.start_time = 0

    def get_stats(self) -> dict[str, object]:
        """Return the collected stats as a dict.

        Keys are ``"calls"``, ``"input_tokens"``, ``"output_tokens"`` and
        ``"all_times"``, followed by everything in ``additional_fields``
        (which overrides a built-in key on name collision).
        """
        return {
            "calls": self.cnt,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "all_times": self.all_times,
            **self.additional_fields,
        }