File size: 10,017 Bytes
fe5c39d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
# -*- coding: utf-8 -*-
# @Date    : 12/23/2023 4:51 PM
# @Author  : stellahong ([email protected])
# @Desc    :
from __future__ import annotations

import ast
import asyncio
import json
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, Field

from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser

# Format instructions appended to the proposal prompt in
# ThoughtSolverBase.generate_thoughts. They ask the model to answer with a
# fenced ```json block holding a list of node dicts, each carrying a
# "node_id" and a "node_state_instruction" entry.
OUTPUT_FORMAT = """
Each output should be strictly a list of nodes, in json format, like this:
```json
    [
        {
            "node_id": str = "unique identifier for a solution, can be an ordinal",
            "node_state_instruction": "specified sample of solution",
        },
        ...
    ]
```
"""


class ThoughtSolverBase(BaseModel):
    """Common base for the Tree-of-Thought solvers in this module.

    Bundles the thought tree being built, the LLM used to propose and score
    thoughts, and the solver configuration. Subclasses provide ``solve``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Tree under construction; the solvers in this file assign it in solve().
    thought_tree: Optional[ThoughtTree] = Field(default=None)
    # LLM used for proposing and evaluating thoughts (excluded from dumps).
    llm: BaseLLM = Field(default_factory=LLM, exclude=True)
    # Parser, evaluator and sampling/selection parameters.
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)

    def __init__(self, **kwargs: Any):
        """Build the model, then switch off the system prompt on the LLM."""
        super().__init__(**kwargs)
        self.llm.use_system_prompt = False

    async def solve(self, init_prompt):
        """
        Solve method for subclasses to implement.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement the solve method")

    async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
        """
        Generate children thoughts based on the current state.

        The proposal prompt built by the configured parser is sent to the LLM
        together with ``OUTPUT_FORMAT``; the code block in the reply is parsed
        into a list of node descriptions and handed to the thought tree.

        Args:
            current_state (str): The current state for which thoughts are generated.
            current_node (ThoughtNode): The current node in the thought tree.

        Returns:
            List[ThoughtNode]: List of nodes representing the generated thoughts.

        Raises:
            ValueError, SyntaxError: If the reply is neither valid JSON nor a
                valid Python literal.
        """
        state_prompt = self.config.parser.propose(
            current_state=current_state, n_generate_sample=self.config.n_generate_sample
        )
        rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
        raw_thoughts = CodeParser.parse_code(block="", text=rsp)
        # SECURITY: the reply is untrusted model output. It used to be passed
        # to eval(), which executes arbitrary code. Parse it as data instead:
        # strict JSON first, then a Python literal (tolerates trailing commas
        # and single quotes, which the format example itself contains).
        try:
            thoughts = json.loads(raw_thoughts)
        except json.JSONDecodeError:
            thoughts = ast.literal_eval(raw_thoughts)
        # TODO: the model may not follow instructions and return too many
        # nodes; the result is not yet truncated to config.n_generate_sample.
        return self.thought_tree.update_node(thoughts, current_node=current_node)

    async def evaluate_node(self, node, parent_value) -> None:
        """
        Evaluate a node and update its status and value.

        Args:
            node (ThoughtNode): The node to be evaluated.
            parent_value (float): The parent node's value.

        Returns:
            None
        """
        eval_prompt = self.config.parser.value(input=node.name, node_id=node.id)
        evaluation = await self.llm.aask(msg=eval_prompt)

        value = self.config.evaluator(evaluation, node_id=node.id)
        status = self.config.evaluator.status_verify(value)

        node.update_valid_status(status=status)
        # Scores accumulate along the path: child value = parent value + own score.
        node.update_value(parent_value + value)

    def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
        """
        Select nodes based on the configured selection method.

        Nodes that are not selected are detached from the tree by clearing
        their ``parent``.

        Args:
            thought_nodes (List[ThoughtNode]): List of nodes to be selected.

        Returns:
            List[ThoughtNode]: List of selected nodes.

        Raises:
            NotImplementedError: If the configured method is ``MethodSelect.SAMPLE``.
        """
        # nodes to be selected
        nodes = []
        if self.config.method_select == MethodSelect.SAMPLE:
            raise NotImplementedError
        elif self.config.method_select == MethodSelect.GREEDY:
            # Keep the n_select_sample highest-valued candidates.
            nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
        for node in thought_nodes:
            if node not in nodes:
                node.parent = None  # remove the node from the tree
        return nodes

    def update_solution(self):
        """
        Select the result with the highest score.

        Returns:
            - List[ThoughtNode]: List of nodes representing the best solution.
            - List[str]: List of node names forming the best solution path.
        """
        best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
        best_solution_path = self.thought_tree.parse_node_path(best_node)
        return [best_node], best_solution_path


class BFSSolver(ThoughtSolverBase):
    """Tree-of-Thought solver that grows the tree one level at a time."""

    async def solve(self, init_prompt=""):
        """
        Run a breadth-first Tree-of-Thought search starting from a prompt.

        For ``config.max_steps`` rounds the current frontier is expanded,
        the children are scored, and the selected children become the next
        frontier. The tree is displayed after every round.

        Args:
            init_prompt (str): Prompt stored in the root node of the tree.

        Returns:
            List[str]: Path of the highest-valued node found.
        """
        root_node = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root_node)

        frontier = [root_node]
        for _ in range(self.config.max_steps):
            candidates = await self._bfs_build(frontier)
            frontier = self.select_nodes(candidates)
            self.thought_tree.show()

        _, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path

    async def _bfs_build(self, current_nodes):
        """
        Expand every node of the current level concurrently.

        Args:
            current_nodes (List[ThoughtNode]): Nodes whose children are generated.

        Returns:
            List[ThoughtNode]: All generated children, flattened in the order
            of their parents.
        """
        expansions = [
            self.generate_and_evaluate_nodes(self.config.parser(parent.name), parent.value, parent)
            for parent in current_nodes
        ]
        children_per_parent = await asyncio.gather(*expansions)

        flattened = []
        for children in children_per_parent:
            flattened.extend(children)
        return flattened

    async def generate_and_evaluate_nodes(self, current_state, current_value, node):
        """
        Generate the children of ``node`` and score them concurrently.

        Args:
            current_state: Parsed state of ``node`` used to build the proposal prompt.
            current_value: Value of ``node``, passed as the parent value when scoring.
            node (ThoughtNode): The node being expanded.

        Returns:
            List[ThoughtNode]: The evaluated child nodes.
        """
        children = await self.generate_thoughts(current_state, current_node=node)
        scoring_jobs = [self.evaluate_node(child, parent_value=current_value) for child in children]
        await asyncio.gather(*scoring_jobs)
        return children


class DFSSolver(ThoughtSolverBase):
    """Tree-of-Thought solver that follows the first generated thought in depth."""

    async def _dfs(self, root_node):
        """
        Perform Depth-First Search (DFS) on the thought tree.

        At each step the current node is expanded and only the first
        generated child is evaluated and followed. The walk stops after
        ``config.max_steps`` steps, when no child is generated, or once two
        children have been judged invalid.

        Args:
            root_node (ThoughtNode): The root node of the thought tree.

        Returns:
            List[str]: The solution path obtained through DFS.
        """
        impossible_state_cnt = 0
        node = root_node
        # max_steps lives on the solver config, not on the solver itself.
        for _ in range(self.config.max_steps):
            current_state = self.config.parser(node.name)
            current_value = node.value
            thought_nodes = await self.generate_thoughts(current_state, current_node=node)
            if not thought_nodes:
                # Nothing to descend into; avoid indexing an empty list.
                logger.info("no thoughts generated, break")
                break
            candidate = thought_nodes[0]
            await self.evaluate_node(candidate, parent_value=current_value)
            if candidate.valid_status is False:
                impossible_state_cnt += 1
            if impossible_state_cnt >= 2:
                logger.info("impossible state reached, break")
                break
            node = candidate
        _solution_path = self.thought_tree.parse_node_path(node)
        self.thought_tree.show()

        return _solution_path

    async def solve(self, init_prompt="", root: Optional[ThoughtNode] = None):
        """
        Solve the problem using Depth-First Search (DFS) strategy.

        Args:
            init_prompt (str): The initial prompt for the solver.
            root (ThoughtNode, optional): Ignored; a fresh root is always built
                from ``init_prompt``. Kept only for backward compatibility.

        Returns:
            List[str]: The best solution path obtained through DFS.
        """
        root_node = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root_node)
        for _ in range(self.config.n_solution_sample):
            # FIXME: backtracking is needed - when the current node is unusable,
            # return to its parent and generate new nodes to keep exploring.
            await self._dfs(root_node)

        _, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path


class MCTSSolver(ThoughtSolverBase):
    """Placeholder for a Monte Carlo Tree Search solver; not implemented yet."""

    async def solve(self, init_prompt=""):
        """
        Not implemented.

        Args:
            init_prompt (str): The initial prompt for the solver (unused).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()


class TreeofThought(BaseModel):
    """Front end that picks a Tree-of-Thought solver by strategy and runs it."""

    # Same configuration style as ThoughtSolverBase (pydantic v2 ConfigDict).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Configuration handed to the concrete solver.
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
    # Replaced in __init__ by the solver matching ``strategy``.
    solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
    # Search strategy that decides which solver class is used.
    strategy: Strategy = Field(default=Strategy.BFS)

    def __init__(self, **kwargs: Any):
        """Build the model, then install the solver for ``self.strategy``."""
        super().__init__(**kwargs)
        self._initialize_solver(self.strategy)

    def _initialize_solver(self, strategy):
        """
        Initialize the solver based on the chosen strategy.

        The new solver is stored on ``self.solver``; nothing is returned.

        Args:
            strategy (Strategy): The strategy to use for solving.

        Raises:
            NotImplementedError: If ``strategy`` is not BFS, DFS or MCTS.
        """
        if strategy == Strategy.BFS:
            self.solver = BFSSolver(config=self.config)
        elif strategy == Strategy.DFS:
            self.solver = DFSSolver(config=self.config)
        elif strategy == Strategy.MCTS:
            self.solver = MCTSSolver(config=self.config)
        else:
            raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")

    async def solve(self, init_prompt=""):
        """
        Solve the problem using the solver selected at construction time.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Returns:
            Any: The solution returned by the selected solver's ``solve``.
        """
        # Propagate the solver's result; it was previously awaited and dropped.
        return await self.solver.solve(init_prompt)