# Last commit: "removed rule" by Kate Forsberg (727c6b5)
# File-viewer residue: raw | history blame | 6.03 kB
import glob
import gradio as gr
from typing import Any
from dotenv import load_dotenv
from griptape.structures import Agent
from griptape.tasks import PromptTask
from griptape.drivers import (
LocalConversationMemoryDriver,
GriptapeCloudStructureRunDriver,
LocalFileManagerDriver,
LocalStructureRunDriver,
)
from griptape.memory.structure import ConversationMemory
from griptape.tools import StructureRunClient, FileManager
from griptape.rules import Rule, Ruleset
from griptape.config import AnthropicStructureConfig
import time
import os
import re
# Load environment variables (python-dotenv reads a local .env file, if present,
# into os.environ). This runs at import time so the os.getenv()/os.environ.get()
# lookups in build_agent (GT_STRUCTURE_ID, GT_CLOUD_API_KEY, BASE_URL) see them.
load_dotenv()
# The agents below turn the user's chat messages into a structured query prompt, which is then sent to the query structure hosted on Griptape Cloud.
def user(user_message, history):
    """Record the user's turn in the Gradio chat history.

    Appends ``[user_message, None]`` to ``history`` *in place*; the ``None``
    slot is the bot reply, which ``bot`` fills in afterwards. The returned
    empty string is meant for the input textbox (in the usual Gradio Blocks
    chat pattern this clears it).

    NOTE(review): nothing else in this file references this function --
    ``gr.ChatInterface`` at the bottom calls ``send_message`` directly.

    Returns:
        The tuple ``("", history)``; ``history`` is the same list object that
        was passed in.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    cleared_textbox = ""
    return cleared_textbox, history
def bot(history, request: gr.Request = None):
    """Stream the agent's reply to the latest user turn into ``history``.

    Generator companion to ``user``. It sends the message stored in
    ``history[-1][0]`` through ``send_message``, then fills ``history[-1][1]``
    one character at a time. After each character it yields the updated
    history, so a Gradio Chatbot can render a typing effect.

    NOTE(review): nothing else in this file references this function --
    ``gr.ChatInterface`` at the bottom calls ``send_message`` directly.

    Args:
        history: Gradio chat history, a list of ``[user, bot]`` pairs. The
            last pair is the turn to answer.
        request: The Gradio request, injected automatically when this is used
            as an event handler (Gradio recognises the ``gr.Request``
            annotation). It is forwarded to ``send_message``, which uses it to
            pick the per-session memory file.

    Yields:
        ``history`` -- the same list, mutated in place.
    """
    # send_message takes (message, history, request). The previous call passed
    # only the message, which raised TypeError before anything was streamed.
    response = send_message(history[-1][0], history, request)
    history[-1][1] = ""
    if not response:
        # Still emit one update so the UI shows the (empty) bot turn.
        yield history
        return
    for character in response:
        history[-1][1] += character
        time.sleep(0.005)  # short per-character delay for the typing animation
        yield history
def create_prompt_task(session_id: str, message: str) -> PromptTask:
    """Build the prompt task that asks the LLM to restructure the user's request.

    The prompt embeds the user's ``message`` and names the session's memory
    file (``{session_id}.json``). It then lists the target query template: one
    ``<attribute>: <x>`` line per attribute below.

    NOTE(review): ``message`` is interpolated verbatim into the task input. If
    this griptape version renders string task input as a Jinja template, braces
    typed by the user could be interpreted as template syntax -- confirm.

    Args:
        session_id: Session identifier; only used to name the memory file in
            the prompt text.
        message: The user's latest message.

    Returns:
        A ``PromptTask`` whose input is the assembled prompt.
    """
    attributes = (
        "years experience",
        "location",
        "role",
        "skills",
        "expected salary",
        "availability",
        "past companies",
        "past projects",
        "show reel details",
    )
    instruction = (
        f"Re-structure the values from the user's questions: '{message}' "
        f"and the input value from the conversation memory '{session_id}.json' "
        "to fit the following format. Leave out attributes that aren't "
        "important to the user:"
    )
    template = "".join(f"{name}: <x>\n" for name in attributes)
    return PromptTask(f"\n{instruction}\n{template}")
def build_talk_agent(session_id: str, message: str) -> Agent:
    """Build the helper agent that turns the user's input into a structured query.

    The agent uses the same ``{session_id}.json`` conversation-memory file as
    the agent from ``build_agent``. It gets a ``FileManager`` tool backed by
    ``LocalFileManagerDriver`` with ``off_prompt=False``, presumably so it can
    read that memory file, which its prompt refers to. Its only task is the
    prompt from ``create_prompt_task``.

    Args:
        session_id: Session identifier used to name the memory file.
        message: The user's latest message, embedded in the task prompt.

    Returns:
        A configured, not-yet-run ``Agent`` using ``AnthropicStructureConfig``.
    """
    guidance = (
        "You are responsible for structuring a user's questions into a specific format for a query.",
        "You ask the user follow-up questions to fill in missing information for the format you are trying to fit.",
        "If the user has no preference for a specific attribute, then you can remove it from the query.",
        "Only return the current query structure and any questions to fill in missing information.",
    )
    talk_rules = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in guidance],
    )
    files = FileManager(
        name="FileManager",
        file_manager_driver=LocalFileManagerDriver(),
        off_prompt=False,
    )
    llm_config = AnthropicStructureConfig()
    memory = ConversationMemory(
        driver=LocalConversationMemoryDriver(file_path=f"{session_id}.json")
    )
    task = create_prompt_task(session_id, message)
    return Agent(
        config=llm_config,
        conversation_memory=memory,
        tools=[files],
        tasks=[task],
        rulesets=[talk_rules],
    )
def build_agent(session_id: str, message: str) -> Agent:
    """Build the top-level chat agent for one incoming message.

    A fresh agent is created for every message. Conversation history is kept
    locally in ``{session_id}.json``. The caller passes the Gradio
    ``session_hash``, so each session gets its own memory file.

    Tools, in order:
      * ``FormulateQueryFromUser`` -- runs the agent from ``build_talk_agent``
        in-process through ``LocalStructureRunDriver``.
      * ``QueryResumeSearcher`` -- runs the Griptape Cloud structure named by
        ``GT_STRUCTURE_ID``, authenticated with ``GT_CLOUD_API_KEY``. It polls
        with a wait interval of 5 (presumably seconds) for at most 30 attempts.

    Args:
        session_id: Session identifier used to name the memory file.
        message: The user's latest message, forwarded to the helper agent.

    Returns:
        A configured, not-yet-run ``Agent`` using ``AnthropicStructureConfig``.
    """
    instructions = (
        "You are responsible for structuring a user's questions into a specific format for a query and then querying.",
        "Only return the result of the query, do not provide additional commentary.",
        "Only perform one task at a time.",
        "Do not perform the query unless the user has said 'Done' with formulating.",
        "Only perform the query with the proper query structure as one string argument.",
        "If you reformulate the query, then you must ask the user if they are 'Done' again.",
        # NOTE(review): this agent itself has no file-management tool, so it
        # cannot directly act on this rule.
        "If the user says they want to start over, then you must delete the conversation memory file.",
    )
    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in instructions],
    )

    # NOTE(review): BASE_URL is only printed here. It is not passed to the
    # driver below, so the driver's own default base URL is what is used.
    print("Base URL", os.environ.get("BASE_URL", "https://cloud.griptape.ai"))

    cloud_driver = GriptapeCloudStructureRunDriver(
        structure_id=os.getenv("GT_STRUCTURE_ID"),
        api_key=os.getenv("GT_CLOUD_API_KEY"),
        structure_run_wait_time_interval=5,
        structure_run_max_wait_time_attempts=30,
    )
    searcher = StructureRunClient(
        name="QueryResumeSearcher",
        description="Use it to search for a candidate with the query.",
        driver=cloud_driver,
    )

    def make_talk_agent() -> Agent:
        # Factory handed to LocalStructureRunDriver, which calls it to obtain
        # the helper agent (bound to this session and message).
        return build_talk_agent(session_id, message)

    formulator = StructureRunClient(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        driver=LocalStructureRunDriver(structure_factory_fn=make_talk_agent),
    )
    return Agent(
        config=AnthropicStructureConfig(),
        conversation_memory=ConversationMemory(
            driver=LocalConversationMemoryDriver(file_path=f"{session_id}.json")
        ),
        tools=[formulator, searcher],
        rulesets=[ruleset],
    )
def delete_json(session_id: str) -> None:
    """Delete the conversation-memory file ``{session_id}.json`` if it exists.

    The exact path is removed directly instead of going through ``glob``.
    Previously ``session_id`` was treated as a glob pattern, so an id
    containing ``*``, ``?`` or ``[`` could match -- and delete -- other
    sessions' memory files.

    Args:
        session_id: Session identifier used to name the memory file. The path
            is relative to the current working directory.
    """
    try:
        os.remove(f"{session_id}.json")
    except FileNotFoundError:
        # Nothing to delete; a missing file was also a no-op before.
        pass
def send_message(message: str, history, request: gr.Request = None) -> Any:
    """Answer one chat message with a freshly built agent.

    Used as the ``fn`` of ``gr.ChatInterface``, which calls it with the new
    message and the chat history. Gradio injects ``request`` because of its
    ``gr.Request`` annotation.

    Args:
        message: The user's latest message.
        history: Gradio chat history. It is unused here; conversation context
            comes from the per-session memory file instead.
        request: The current Gradio request. Its ``session_hash`` names the
            memory file. When there is no request (e.g. a direct call outside a
            Gradio event), the shared ``"default"`` session id is used.
            Previously this case crashed with ``UnboundLocalError``.

    Returns:
        The ``value`` of the agent's final output artifact.
    """
    session_hash = request.session_hash if request else "default"
    agent = build_agent(session_hash, message)
    response = agent.run(message)
    # NOTE: delete_json(session_hash) can reset a session's memory. It used to
    # be called here when the message contained "done", but that is disabled.
    return response.output.value
# Build the chat UI at module level so `demo` stays importable (e.g. by
# Gradio's reload mode). Only start the server when this file runs as a
# script, not as an import side effect.
demo = gr.ChatInterface(fn=send_message)

if __name__ == "__main__":
    demo.launch()