Spaces:
Running
Running
File size: 6,631 Bytes
926675f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
# main class chaining Planner, Worker and Solver.
import re
import time
from nodes.Planner import Planner
from nodes.Solver import Solver
from nodes.Worker import *
from utils.util import *
class PWS:
    """Pipeline chaining a Planner, tool Workers and a Solver.

    The Planner's output is parsed into ``Plan: ...`` lines and
    ``#E<n> = Tool[input]`` evidence assignments. Each assignment is executed
    by the matching worker in ``WORKER_REGISTRY``. The Solver then answers the
    question from the plans interleaved with the gathered evidence.
    """

    # Start of an evidence-assignment line, e.g. "#E1 = Google[...]".
    _EVIDENCE_LINE = re.compile(r"#E\d")
    # A complete evidence identifier such as "#E1" or "#E12".
    _EVIDENCE_ID = re.compile(r"#E\d+")
    # Placeholder used wherever no evidence is available.
    _NO_EVIDENCE = "No evidence found"

    def __init__(self, available_tools=None, fewshot="\n", planner_model="text-davinci-003",
                 solver_model="text-davinci-003"):
        """Build the planner/solver pair and per-run bookkeeping.

        Args:
            available_tools: names of the workers the planner may use. Tool calls
                naming any other tool yield "No evidence found". Defaults to
                ["Google", "LLM"]. A fresh list is created per instance rather
                than sharing one mutable default.
            fewshot: few-shot prompt text passed to the Planner.
            planner_model: model name for the Planner (also used for pricing).
            solver_model: model name for the Solver (also used for pricing).
        """
        if available_tools is None:
            available_tools = ["Google", "LLM"]
        self.workers = available_tools
        self.planner = Planner(workers=self.workers,
                               model_name=planner_model,
                               fewshot=fewshot)
        self.solver = Solver(model_name=solver_model)
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
        self.planner_token_unit_price = get_token_unit_price(planner_model)
        self.solver_token_unit_price = get_token_unit_price(solver_model)
        # Tokens spent inside the LLM/Calculator workers are priced as
        # text-davinci-003 regardless of the planner/solver models.
        self.tool_token_unit_price = get_token_unit_price("text-davinci-003")
        # Multiplied by the number of Google queries made during a run.
        self.google_unit_price = 0.01

    def run(self, input):
        """Answer one question by planning, gathering evidence, then solving.

        Args:
            input: the question line, e.g. "Question: What is the capital of France?"

        Returns:
            dict with keys "wall_time", "input", "output", "planner_log",
            "worker_log", "solver_log", "tool_usage", "steps" (number of parsed
            plans + 1), "total_tokens", "token_cost", "tool_cost" and "total_cost".
        """
        # run is stateless, so per-run state from a previous call is cleared first.
        self._reinitialize()
        start = time.time()

        # Plan
        planner_response = self.planner.run(input, log=True)
        plan = planner_response["output"]
        planner_log = planner_response["input"] + planner_response["output"]
        self.plans = self._parse_plans(plan)
        self.planner_evidences = self._parse_planner_evidences(plan)

        # Work
        self._get_worker_evidences()
        worker_log = self._build_worker_log()

        # Solve
        solver_response = self.solver.run(input, worker_log, log=True)
        output = solver_response["output"]
        solver_log = solver_response["input"] + solver_response["output"]
        wall_time = time.time() - start

        planner_tokens = planner_response["prompt_tokens"] + planner_response["completion_tokens"]
        solver_tokens = solver_response["prompt_tokens"] + solver_response["completion_tokens"]
        tool_tokens = self.tool_counter.get("LLM_token", 0) + self.tool_counter.get("Calculator_token", 0)
        token_cost = self.planner_token_unit_price * planner_tokens \
            + self.solver_token_unit_price * solver_tokens \
            + self.tool_token_unit_price * tool_tokens
        tool_cost = self.tool_counter.get("Google", 0) * self.google_unit_price

        return {
            "wall_time": wall_time,
            "input": input,
            "output": output,
            "planner_log": planner_log,
            "worker_log": worker_log,
            "solver_log": solver_log,
            "tool_usage": self.tool_counter,
            "steps": len(self.plans) + 1,
            "total_tokens": planner_tokens + solver_tokens + tool_tokens,
            "token_cost": token_cost,
            "tool_cost": tool_cost,
            "total_cost": token_cost + tool_cost,
        }

    def _parse_plans(self, response):
        """Return every line of the planner output that starts with "Plan:"."""
        return [line for line in response.splitlines() if line.startswith("Plan:")]

    def _parse_planner_evidences(self, response):
        """Map evidence ids to their tool calls from the planner output.

        A line such as "#E3 = LLM[...]" yields {"#E3": "LLM[...]"}.
        - Identifiers of any digit count are accepted.
        - A malformed identifier before "=" (e.g. "#E1a") maps to the
          no-evidence placeholder.
        - Lines that only look like an id but contain no "=" are skipped instead
          of raising.
        """
        evidences = {}
        for line in response.splitlines():
            if not self._EVIDENCE_LINE.match(line):
                continue
            key, sep, tool_call = line.partition("=")
            if not sep:
                continue
            key = key.strip()
            if self._EVIDENCE_ID.fullmatch(key):
                evidences[key] = tool_call.strip()
            else:
                evidences[key] = self._NO_EVIDENCE
        return evidences

    def _build_worker_log(self):
        """Render each plan followed by the evidence gathered for it.

        The i-th plan (1-based) is paired with evidence "#E<i>". A plan without a
        matching evidence gets the placeholder instead of raising KeyError.
        """
        entries = []
        for step, plan in enumerate(self.plans, start=1):
            evidence = self.worker_evidences.get(f"#E{step}", self._NO_EVIDENCE)
            entries.append(f"{plan}\nEvidence:\n{evidence}\n")
        return "".join(entries)

    def _substitute_evidences(self, tool_input):
        """Replace each known "#E<n>" reference in tool_input with "[evidence]".

        A single regex pass matches whole identifiers. As a result, "#E1" never
        clobbers the prefix of "#E10", and evidence text spliced in is not
        re-scanned. References without evidence yet are left as-is.
        """
        def _expand(match):
            var = match.group(0)
            if var in self.worker_evidences:
                return "[" + self.worker_evidences[var] + "]"
            return var

        return self._EVIDENCE_ID.sub(_expand, tool_input)

    # use planner evidences to assign tasks to respective workers.
    def _get_worker_evidences(self):
        """Execute each planner evidence in order and store its result.

        - A tool call without "[" is stored verbatim as the evidence.
        - A call like "Tool[input]" runs WORKER_REGISTRY[Tool] when Tool is in
          self.workers, after earlier evidences referenced in the input have been
          substituted.
        - Any other tool yields the no-evidence placeholder.
        """
        for e, tool_call in self.planner_evidences.items():
            if "[" not in tool_call:
                self.worker_evidences[e] = tool_call
                continue
            tool, tool_input = tool_call.split("[", 1)
            tool = tool.strip()
            # Drop the closing bracket only if it is actually there.
            if tool_input.endswith("]"):
                tool_input = tool_input[:-1]
            tool_input = self._substitute_evidences(tool_input)
            if tool not in self.workers:
                self.worker_evidences[e] = self._NO_EVIDENCE
                continue
            evidence = WORKER_REGISTRY[tool].run(tool_input)
            self.worker_evidences[e] = evidence
            self._record_tool_usage(tool, tool_input, evidence)

    def _record_tool_usage(self, tool, tool_input, evidence):
        """Update self.tool_counter for one executed tool call.

        - Google is counted per query.
        - LLM and Calculator usage is estimated as len(text) // 4 tokens (a
          rough character-based heuristic, not a real tokenizer count).
        """
        if tool == "Google":
            self.tool_counter["Google"] = self.tool_counter.get("Google", 0) + 1  # number of query
        elif tool == "LLM":
            self.tool_counter["LLM_token"] = self.tool_counter.get("LLM_token", 0) \
                + len(tool_input + evidence) // 4
        elif tool == "Calculator":
            # The math chain's prompt template is counted as part of the tool's input.
            template = LLMMathChain(llm=OpenAI(), verbose=False).prompt.template
            self.tool_counter["Calculator_token"] = self.tool_counter.get("Calculator_token", 0) \
                + len(template + tool_input + evidence) // 4

    def _reinitialize(self):
        """Clear the per-run state (plans, evidences, tool counters)."""
        self.plans = []
        self.planner_evidences = {}
        self.worker_evidences = {}
        self.tool_counter = {}
class PWS_Base(PWS):
    """PWS preset using the ``fewshots.HOTPOTQA_PWS_BASE`` prompt and Wikipedia + LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_BASE, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Configure the base preset.

        Args:
            available_tools: worker names. Defaults to ["Wikipedia", "LLM"]. A
                fresh list is built per instance instead of sharing a mutable
                default argument.
        """
        if available_tools is None:
            available_tools = ["Wikipedia", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
class PWS_Extra(PWS):
    """PWS preset using the ``fewshots.HOTPOTQA_PWS_EXTRA`` prompt and Google + Calculator + LLM tools."""

    def __init__(self, fewshot=fewshots.HOTPOTQA_PWS_EXTRA, planner_model="text-davinci-003",
                 solver_model="text-davinci-003", available_tools=None):
        """Configure the extra-tools preset.

        Args:
            available_tools: worker names. Defaults to ["Google", "Calculator",
                "LLM"]. A fresh list is built per instance instead of sharing a
                mutable default argument.
        """
        if available_tools is None:
            available_tools = ["Google", "Calculator", "LLM"]
        super().__init__(available_tools=available_tools,
                         fewshot=fewshot,
                         planner_model=planner_model,
                         solver_model=solver_model)
|