(self, initial_state: DialogueState, num_simulations: int)
| 90 | node = node.parent |
| 91 | |
def search(self, initial_state: DialogueState, num_simulations: int) -> DialogueState:
    """Run MCTS simulations and return the state of the most-visited root child.

    A root node is created from ``initial_state`` only when ``self.root`` is
    not already set; on later calls the existing tree is reused and
    ``initial_state`` is not consulted.

    Each simulation selects a node starting from the root, expands it
    unless its state is terminal, simulates from the resulting node, and
    backpropagates the simulated value.

    Args:
        initial_state: State used to build the root node on first use.
        num_simulations: Number of select/expand/simulate/backpropagate
            iterations to run.

    Returns:
        The state of the root child with the highest visit count. If the
        root has no children after the simulations (for example when
        ``num_simulations`` is 0 or the root state is terminal, so nothing
        was expanded), the root's own state is returned instead of
        raising ``ValueError`` from ``max()`` on an empty sequence.
    """
    logger.debug(f"Starting MCTS search with {num_simulations} simulations")
    if not self.root:
        self.root = MCTSNode(initial_state)
        # Register the new root with the visualisation graph, keyed by id().
        self.graph.add_node(id(self.root))
        self.node_labels[id(self.root)] = "Root\nVisits: 0\nValue: 0.00"
        logger.debug("Created root node")

    for i in range(num_simulations):
        logger.debug(f"Starting simulation {i+1}")
        node = self.select(self.root)
        # Terminal nodes are not expanded; they are simulated as-is.
        if not self.is_terminal(node.state):
            node = self.expand(node)
        value = self.simulate(node)
        self.backpropagate(node, value)

    # default=None avoids ValueError when the root was never expanded.
    best_child = max(self.root.children, key=lambda c: c.visits, default=None)
    if best_child is None:
        logger.debug("Search complete. Root has no children; returning root state")
        return self.root.state

    logger.debug(f"Search complete. Best child node: Visits: {best_child.visits}, Value: {best_child.value}")
    return best_child.state
| 111 | |
| 112 | def generate_actions(self, state: DialogueState) -> List[str]: |
| 113 | logger.debug("Generating actions for current state") |
no test coverage detected