Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.
| 707 | |
| 708 | |
| 709 | class ClaudeModel(dspy.dsp.modules.lm.LM): |
| 710 | """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.""" |
| 711 | |
def __init__(
    self,
    model: str,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    **kwargs,
):
    """Create the Anthropic client, generation settings and token counters.

    Args:
        model: Model identifier; passed to the base class and stored both on
            the instance and under the ``"model"`` key of ``self.kwargs``.
        api_key: API key. When ``None`` the ``ANTHROPIC_API_KEY`` environment
            variable is read instead.
        api_base: Endpoint URL stored on ``self.api_base``. When ``None`` the
            default messages endpoint URL is stored.
        **kwargs: Extra generation settings merged into ``self.kwargs``.
            ``n`` (or its alias ``num_generations``) is removed from the
            extras and stored under ``"n"``.

    Raises:
        ImportError: If the ``anthropic`` package is not installed.
    """
    super().__init__(model)
    try:
        from anthropic import Anthropic
    except ImportError as err:
        raise ImportError("Claude requires `pip install anthropic`.") from err

    self.provider = "anthropic"

    # Fall back to the environment only when no key was passed explicitly.
    if api_key is None:
        api_key = os.environ.get("ANTHROPIC_API_KEY")
    self.api_key = api_key

    # NOTE(review): api_base is only recorded here; the client below is
    # constructed from the API key alone.
    self.api_base = (
        api_base if api_base is not None else "https://api.anthropic.com/v1/messages"
    )

    settings = {
        "temperature": kwargs.get("temperature", 0.0),
        "max_tokens": min(kwargs.get("max_tokens", 4096), 4096),
        "top_p": kwargs.get("top_p", 1.0),
        "top_k": kwargs.get("top_k", 1),
    }
    # Both spellings are always removed from kwargs; "n" wins when both are given.
    alias_count = kwargs.pop("num_generations", 1)
    settings["n"] = kwargs.pop("n", alias_count)
    # NOTE(review): the remaining caller kwargs overwrite the entries above, so
    # an explicit "max_tokens" replaces the clamped value computed earlier --
    # confirm whether the 4096 cap is meant to apply to caller-supplied values.
    settings.update(kwargs)
    settings["model"] = model
    self.kwargs = settings

    self.history: list[dict[str, Any]] = []
    self.client = Anthropic(api_key=api_key)
    self.model = model

    # Running token totals; the lock guards updates made in log_usage.
    self._token_usage_lock = threading.Lock()
    self.prompt_tokens = 0
    self.completion_tokens = 0
| 748 | |
def log_usage(self, response):
    """Add the token counts reported by an Anthropic API response to the totals.

    Args:
        response: Object exposing a ``usage`` attribute that carries
            ``input_tokens`` and ``output_tokens``. A falsy ``usage`` leaves
            the counters untouched.
    """
    reported = response.usage
    if not reported:
        return

    # Update both counters under the lock so they change together.
    with self._token_usage_lock:
        self.prompt_tokens += reported.input_tokens
        self.completion_tokens += reported.output_tokens
| 756 | |
| 757 | def get_usage_and_reset(self): |
| 758 | """Get the total tokens used and reset the token usage.""" |
| 759 | usage = { |
| 760 | self.model: { |
| 761 | "prompt_tokens": self.prompt_tokens, |
| 762 | "completion_tokens": self.completion_tokens, |
| 763 | } |
| 764 | } |
| 765 | self.prompt_tokens = 0 |
| 766 | self.completion_tokens = 0 |