XiangJinYu's picture
add metagpt
fe5c39d verified
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong ([email protected])
# @Desc : Tree-of-Thought solvers (BFS / DFS / MCTS stub) and the TreeofThought entry point.
from __future__ import annotations

import ast
import asyncio
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, Field

from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser
# Prompt suffix appended to the proposal prompt in `ThoughtSolverBase.generate_thoughts`.
# It instructs the LLM to answer with a fenced ```json block holding a list of node
# dicts, each carrying a "node_id" and a "node_state_instruction" entry.
# NOTE: this text is sent to the model verbatim; editing it changes runtime behavior.
OUTPUT_FORMAT = """
Each output should be strictly a list of nodes, in json format, like this:
```json
[
{
"node_id": str = "unique identifier for a solution, can be an ordinal",
"node_state_instruction": "specified sample of solution",
},
...
]
```
"""
class ThoughtSolverBase(BaseModel):
    """Shared state and building blocks for the Tree-of-Thought solvers.

    Holds the thought tree being explored, the LLM used to propose and score
    thoughts, and the solver configuration. Subclasses provide the actual
    search strategy by overriding ``solve``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Tree being explored; stays None until a concrete solver assigns one.
    thought_tree: Optional[ThoughtTree] = Field(default=None)
    # LLM used to propose and evaluate thoughts (declared with exclude=True).
    llm: BaseLLM = Field(default_factory=LLM, exclude=True)
    # Supplies the parser, evaluator and the sampling/selection parameters read below.
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)

    def __init__(self, **kwargs: Any):
        """Build the solver, then set ``use_system_prompt`` to False on its LLM."""
        super().__init__(**kwargs)
        self.llm.use_system_prompt = False

    async def solve(self, init_prompt):
        """
        Solve method for subclasses to implement.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Raises:
            NotImplementedError: Always; the base class has no search strategy.
        """
        raise NotImplementedError("Subclasses must implement the solve method")

    async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
        """
        Generate children thoughts based on the current state.

        Args:
            current_state (str): The current state for which thoughts are generated.
            current_node (ThoughtNode): The current node in the thought tree.

        Returns:
            List[ThoughtNode]: List of nodes representing the generated thoughts.

        Raises:
            ValueError, SyntaxError: If the LLM reply does not contain a valid
                Python/JSON-style literal.
        """
        state_prompt = self.config.parser.propose(
            current_state=current_state, n_generate_sample=self.config.n_generate_sample
        )
        rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
        thoughts_text = CodeParser.parse_code(block="", text=rsp)
        # SECURITY: the reply comes from the LLM and is untrusted. This used to be a
        # bare eval(), which would execute arbitrary code embedded in the response.
        # ast.literal_eval only accepts literals (lists, dicts, strings, numbers, ...)
        # and still tolerates the trailing commas shown in OUTPUT_FORMAT.
        thoughts = ast.literal_eval(thoughts_text)
        # FIXME: guard against the LLM not following instructions and generating too many nodes
        # valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
        return self.thought_tree.update_node(thoughts, current_node=current_node)

    async def evaluate_node(self, node, parent_value) -> None:
        """
        Evaluate a node and update its status and value.

        Args:
            node (ThoughtNode): The node to be evaluated.
            parent_value (float): The parent node's value.

        Returns:
            None
        """
        eval_prompt = self.config.parser.value(input=node.name, node_id=node.id)
        evaluation = await self.llm.aask(msg=eval_prompt)
        value = self.config.evaluator(evaluation, node_id=node.id)
        status = self.config.evaluator.status_verify(value)
        node.update_valid_status(status=status)
        # Accumulate the score along the path: a node's value includes its parent's.
        node.update_value(parent_value + value)

    def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
        """
        Select nodes based on the configured selection method.

        Nodes that are not selected get their ``parent`` set to None.

        Args:
            thought_nodes (List[ThoughtNode]): List of nodes to be selected.

        Returns:
            List[ThoughtNode]: List of selected nodes.

        Raises:
            NotImplementedError: If the configured method is ``MethodSelect.SAMPLE``.
        """
        # Nodes to be selected; stays empty for any method other than GREEDY.
        nodes = []
        if self.config.method_select == MethodSelect.SAMPLE:
            raise NotImplementedError
        elif self.config.method_select == MethodSelect.GREEDY:
            # Keep the n_select_sample highest-valued nodes.
            nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
        for node in thought_nodes:
            if node not in nodes:
                node.parent = None  # remove the node from the tree
        return nodes

    def update_solution(self):
        """
        Select the result with the highest score.

        ``best_node`` is None when ``thought_tree.all_nodes`` is empty; it is
        still passed to ``parse_node_path`` in that case.

        Returns:
            - List[ThoughtNode]: List of nodes representing the best solution.
            - List[str]: List of node names forming the best solution path.
        """
        best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
        best_solution_path = self.thought_tree.parse_node_path(best_node)
        return [best_node], best_solution_path
class BFSSolver(ThoughtSolverBase):
    """Tree-of-Thought solver that grows the tree one level per step."""

    async def solve(self, init_prompt=""):
        """
        Run the level-by-level search and report the best path found.

        Each step expands every node on the current frontier, keeps the nodes
        chosen by ``select_nodes`` as the next frontier, and shows the tree.

        Args:
            init_prompt (str): The initial prompt; becomes the root node.

        Returns:
            List[str]: The best solution path reported by ``update_solution``.
        """
        root_node = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root_node)
        frontier = [root_node]
        for _ in range(self.config.max_steps):
            candidates = await self._bfs_build(frontier)
            frontier = self.select_nodes(candidates)
            self.thought_tree.show()
        _, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path

    async def _bfs_build(self, current_nodes):
        """
        Expand every node in ``current_nodes`` and collect all their children.

        Args:
            current_nodes (List[ThoughtNode]): Nodes to expand.

        Returns:
            List[ThoughtNode]: The children of all expanded nodes, flattened
            into one list in the order of ``current_nodes``.
        """
        # One generate-and-evaluate coroutine per parent, run together via gather.
        pending = [
            self.generate_and_evaluate_nodes(self.config.parser(parent.name), parent.value, parent)
            for parent in current_nodes
        ]
        children_per_parent = await asyncio.gather(*pending)
        solutions = []
        for children in children_per_parent:
            solutions.extend(children)
        return solutions

    async def generate_and_evaluate_nodes(self, current_state, current_value, node):
        """
        Propose children for ``node`` and score each of them.

        Args:
            current_state: Parsed state of ``node`` used to prompt for thoughts.
            current_value: Value of ``node``, passed on as each child's parent value.
            node (ThoughtNode): The node being expanded.

        Returns:
            List[ThoughtNode]: The newly generated (and evaluated) child nodes.
        """
        children = await self.generate_thoughts(current_state, current_node=node)
        scoring = [self.evaluate_node(child, parent_value=current_value) for child in children]
        await asyncio.gather(*scoring)
        return children
class DFSSolver(ThoughtSolverBase):
    """Tree-of-Thought solver that repeatedly descends along the first proposed thought."""

    async def _dfs(self, root_node):
        """
        Perform Depth-First Search (DFS) on the thought tree.

        At every step only the first generated thought is evaluated and
        followed. The descent stops after ``config.max_steps`` steps, when a
        second invalid state has been counted, or when no thoughts are produced.

        Args:
            root_node (ThoughtNode): The root node of the thought tree.

        Returns:
            List[str]: The solution path obtained through DFS.
        """
        impossible_state_cnt = 0
        node = root_node
        # The step budget lives on the config (as BFSSolver reads it); the model
        # itself declares no `max_steps` field, so `self.max_steps` would raise.
        for _ in range(self.config.max_steps):
            current_state = self.config.parser(node.name)
            current_value = node.value
            thought_nodes = await self.generate_thoughts(current_state, current_node=node)
            if not thought_nodes:
                # Nothing to descend into; stop instead of failing on thought_nodes[0].
                logger.info("no thoughts generated, break")
                break
            first_child = thought_nodes[0]
            await self.evaluate_node(first_child, parent_value=current_value)
            if first_child.valid_status is False:
                impossible_state_cnt += 1
            if impossible_state_cnt >= 2:
                logger.info("impossible state reached, break")
                break
            node = first_child
        _solution_path = self.thought_tree.parse_node_path(node)
        self.thought_tree.show()
        return _solution_path

    async def solve(self, init_prompt="", root=None):
        """
        Solve the problem using Depth-First Search (DFS) strategy.

        Args:
            init_prompt (str): The initial prompt for the solver.
            root: Accepted for backward compatibility and ignored; a fresh root
                node is always built from ``init_prompt``. The default is None
                so that no ``ThoughtNode`` is created at definition time.

        Returns:
            List[str]: The best solution path obtained through DFS.
        """
        root = ThoughtNode(init_prompt)
        self.thought_tree = ThoughtTree(root)
        for _ in range(self.config.n_solution_sample):
            # FIXME: backtracking is needed: when the current node is unusable, fall back
            # to its parent and generate new nodes to keep exploring.
            await self._dfs(root)
        best_solution, best_solution_path = self.update_solution()
        logger.info(f"best solution is: {best_solution_path}")
        return best_solution_path
class MCTSSolver(ThoughtSolverBase):
    """Placeholder solver for the MCTS strategy; no search is implemented yet."""

    async def solve(self, init_prompt=""):
        """
        Not implemented.

        Args:
            init_prompt (str): The initial prompt for the solver (unused).

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
class TreeofThought(BaseModel):
    """Entry point that picks a solver for the configured strategy and runs it."""

    # Same configuration style as ThoughtSolverBase (replaces the legacy `class Config`).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Configuration handed to whichever solver is created.
    config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
    # Replaced in __init__ by the solver matching `strategy`.
    solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
    # Search strategy used to choose the solver class.
    strategy: Strategy = Field(default=Strategy.BFS)

    def __init__(self, **kwargs: Any):
        """Validate the fields, then install the solver for ``self.strategy``."""
        super().__init__(**kwargs)
        self._initialize_solver(self.strategy)

    def _initialize_solver(self, strategy):
        """
        Initialize the solver based on the chosen strategy.

        The created solver is stored on ``self.solver``; nothing is returned.

        Args:
            strategy (Strategy): The strategy to use for solving.

        Raises:
            NotImplementedError: If ``strategy`` is not BFS, DFS or MCTS.
        """
        if strategy == Strategy.BFS:
            self.solver = BFSSolver(config=self.config)
        elif strategy == Strategy.DFS:
            self.solver = DFSSolver(config=self.config)
        elif strategy == Strategy.MCTS:
            self.solver = MCTSSolver(config=self.config)
        else:
            raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")

    async def solve(self, init_prompt=""):
        """
        Solve the problem using the solver selected at construction time.

        Args:
            init_prompt (str): The initial prompt for the solver.

        Returns:
            Any: The solution obtained using the selected strategy.
        """
        # Propagate the solver's result; previously it was awaited and then dropped,
        # so callers always received None.
        return await self.solver.solve(init_prompt)