MCPcopy
hub / github.com/algorithmicsuperintelligence/optillm / RStar

Class RStar

optillm/rstar.py:25–350  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

23 self.value = 0.0
24
25class RStar:
def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4, request_config: dict = None, request_id: str = None):
    """Configure an rStar MCTS reasoner.

    Args:
        system: System prompt, stored as ``self.system``.
        client: Client object, stored as ``self.client``; presumably used by
            ``generate_response`` to query the model (not visible here).
        model: Model name, stored as ``self.model_name``.
        max_depth: Maximum number of steps taken by one simulation.
        num_rollouts: Number of rollouts launched by ``mcts_async``.
        c: Exploration constant, stored as ``self.c`` (presumably the UCT
            constant used by ``select_action`` -- confirm).
        request_config: Optional per-request settings. Only ``max_tokens``
            is read here; a missing key or an explicit ``None`` value keeps
            the default of 4096.
        request_id: Optional identifier, stored as ``self.request_id``.
    """
    self.client = client
    self.model_name = model
    self.max_depth = max_depth
    self.num_rollouts = num_rollouts
    self.c = c
    # Action labels; simulate_async picks one at random when expanding a leaf.
    self.actions = ["A1", "A2", "A3", "A4", "A5"]
    self.original_question = None
    self.system = system
    # Running completion-token counter; presumably updated by generate_response.
    self.rstar_completion_tokens = 0
    self.request_id = request_id

    # Extract max_tokens from request_config with default. ``dict.get`` with a
    # default only covers a *missing* key, so an explicit ``None`` value would
    # otherwise replace the default with None.
    self.max_tokens = 4096
    if request_config:
        configured_max_tokens = request_config.get('max_tokens')
        if configured_max_tokens is not None:
            self.max_tokens = configured_max_tokens

    logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")
44
async def generate_response_async(self, prompt: str) -> str:
    """Await ``self.generate_response(prompt)`` without blocking the event loop.

    The synchronous ``generate_response`` is executed in a worker thread via
    ``asyncio.to_thread``, and its return value is passed through unchanged.
    """
    blocking_call = self.generate_response
    completion = await asyncio.to_thread(blocking_call, prompt)
    return completion
47
async def expand_async(self, node: Node, action: str) -> Node:
    """Grow ``node`` by one child produced for ``action``.

    The prompt comes from ``self.create_prompt(node.state, action)`` and the
    awaited response becomes the new child's state. The child is constructed
    as ``Node(state, action, node)``, appended to ``node.children`` and returned.
    """
    child_state = await self.generate_response_async(
        self.create_prompt(node.state, action)
    )
    child = Node(child_state, action, node)
    node.children.append(child)
    logger.debug(f"Expanded node with action: {action}")
    return child
55
async def simulate_async(self, node: Node) -> float:
    """Random walk of up to ``self.max_depth`` steps from ``node``, then score it.

    At each step, a node with no children is expanded using an action drawn
    with ``random.choice(self.actions)``. Otherwise one of its existing
    children is chosen uniformly at random. The node reached at the end is
    scored with ``self.evaluate``.

    Returns:
        Whatever ``self.evaluate`` returns for the final node.
    """
    cursor = node
    steps_taken = 0
    logger.debug("Starting simulation")
    while steps_taken < self.max_depth:
        if cursor.children:
            cursor = random.choice(cursor.children)
        else:
            cursor = await self.expand_async(cursor, random.choice(self.actions))
        steps_taken += 1
    # NOTE(review): evaluate() is called synchronously on the event loop; if it
    # performs a blocking model call, it stalls the other concurrent rollouts -- confirm.
    value = self.evaluate(cursor)
    logger.debug(f"Simulation complete. Final value: {value}")
    return value
70
async def mcts_async(self, root_state: str) -> List[Node]:
    """Run ``self.num_rollouts`` concurrent rollouts from a fresh root.

    Every rollout receives the same root node. All rollouts are awaited
    together with ``asyncio.gather``, after which the resulting tree is handed
    to ``self.extract_trajectories``, whose result is returned.
    """
    tree_root = Node(root_state, None)
    rollouts = [self.mcts_rollout_async(tree_root) for _ in range(self.num_rollouts)]
    await asyncio.gather(*rollouts)
    return self.extract_trajectories(tree_root)
78
79 async def mcts_rollout_async(self, root: Node):
80 node = root
81 while node.children:
82 node, _ = self.select_action(node)

Callers 4

test.pyFile · 0.90
execute_single_approachFunction · 0.90
runFunction · 0.90

Calls

no outgoing calls