| 23 | self.value = 0.0 |
| 24 | |
| 25 | class RStar: |
def __init__(
    self,
    system: str,
    client,
    model: str,
    max_depth: int = 3,
    num_rollouts: int = 5,
    c: float = 1.4,
    request_config: dict = None,
    request_id: str = None,
):
    """Store the search configuration and the model client used by RStar.

    Args:
        system: System prompt text; stored on ``self.system``.
        client: Model client object; stored as-is on ``self.client``.
        model: Model name; stored on ``self.model_name``.
        max_depth: Maximum number of steps a simulation walks (see
            ``simulate_async``).
        num_rollouts: Number of rollouts launched by ``mcts_async``.
        c: Stored on ``self.c``; presumably the exploration constant used
            by ``select_action`` (not shown here) — confirm.
        request_config: Optional mapping. Only its ``'max_tokens'`` entry
            is read here. When it is falsy (``None`` or empty),
            ``self.max_tokens`` stays at 4096.
        request_id: Opaque identifier; stored on ``self.request_id``.
    """
    # Caller-supplied configuration.
    self.system = system
    self.client = client
    self.model_name = model
    self.request_id = request_id
    self.max_depth = max_depth
    self.num_rollouts = num_rollouts
    self.c = c

    # Search state and bookkeeping that starts out empty.
    # NOTE(review): labels A1..A5 are presumably the rStar action set — confirm.
    self.actions = ["A1", "A2", "A3", "A4", "A5"]
    self.original_question = None
    self.rstar_completion_tokens = 0

    # Fall back to 4096 when no config is given or it lacks 'max_tokens'.
    fallback_tokens = 4096
    self.max_tokens = (
        request_config.get('max_tokens', fallback_tokens)
        if request_config
        else fallback_tokens
    )

    logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")
| 44 | |
async def generate_response_async(self, prompt: str) -> str:
    """Await ``self.generate_response(prompt)`` without blocking the event loop.

    The synchronous ``generate_response`` is run in a worker thread through
    ``asyncio.to_thread``, and its return value is passed back unchanged.
    """
    blocking_call = self.generate_response
    completion = await asyncio.to_thread(blocking_call, prompt)
    return completion
| 47 | |
async def expand_async(self, node: Node, action: str) -> Node:
    """Grow *node* by one child produced by applying *action* through the model.

    Builds a prompt from ``node.state`` and *action*, awaits the model's
    response, wraps that response in a new ``Node(state, action, node)``,
    appends the new node to ``node.children`` and returns it.

    NOTE(review): the append happens after an ``await``. When rollouts run
    concurrently (as ``mcts_async`` does via ``asyncio.gather``), several
    coroutines may add children to the same node.
    """
    next_state = await self.generate_response_async(self.create_prompt(node.state, action))
    child = Node(next_state, action, node)
    node.children.append(child)
    logger.debug(f"Expanded node with action: {action}")
    return child
| 55 | |
async def simulate_async(self, node: Node) -> float:
    """Random walk of up to ``self.max_depth`` steps below *node*, then score it.

    At each step:
    - a node that has children moves to one of them, chosen at random;
    - a leaf is expanded with a randomly chosen action from ``self.actions``.

    Steps are counted from *node*, not from the tree root. The node reached
    at the end is scored with ``self.evaluate``, and that score is returned.

    NOTE(review): ``self.evaluate`` is called synchronously. If it performs
    blocking I/O, it will stall the event loop during concurrent
    rollouts — confirm.
    """
    logger.debug("Starting simulation")
    cursor = node
    steps = 0
    while steps < self.max_depth:
        if cursor.children:
            cursor = random.choice(cursor.children)
        else:
            cursor = await self.expand_async(cursor, random.choice(self.actions))
        steps += 1
    score = self.evaluate(cursor)
    logger.debug(f"Simulation complete. Final value: {score}")
    return score
| 70 | |
async def mcts_async(self, root_state: str) -> List[Node]:
    """Run ``self.num_rollouts`` rollouts concurrently from one shared root.

    A root ``Node`` is created from *root_state*. All rollout coroutines
    share that root and are awaited together with ``asyncio.gather``.
    Afterwards ``self.extract_trajectories(root)`` is returned.

    NOTE(review): because the rollouts are started together, each one may
    see the tree before the others have updated it — confirm that this is
    intended.
    """
    root = Node(root_state, None)
    rollouts = [self.mcts_rollout_async(root) for _ in range(self.num_rollouts)]
    await asyncio.gather(*rollouts)
    return self.extract_trajectories(root)
| 78 | |
| 79 | async def mcts_rollout_async(self, root: Node): |
| 80 | node = root |
| 81 | while node.children: |
| 82 | node, _ = self.select_action(node) |
no outgoing calls