Kate Forsberg
dependencies
60cb3bc
raw
history blame
6.8 kB
import glob
from venv import create
import gradio as gr
from typing import Any
from dotenv import load_dotenv
import requests
from griptape.structures import Agent
from griptape.tasks import PromptTask
from griptape.drivers import (
LocalConversationMemoryDriver,
GriptapeCloudStructureRunDriver,
LocalFileManagerDriver,
LocalStructureRunDriver,
GriptapeCloudConversationMemoryDriver,
)
from griptape.memory.structure import ConversationMemory
from griptape.tools import StructureRunTool, FileManagerTool
from griptape.rules import Rule, Ruleset
from griptape.configs.drivers import AnthropicDriversConfig
from griptape.configs import Defaults
import time
import os
from urllib.parse import urljoin
# Pull settings from a local .env file (when one exists) before anything below
# reads os.environ.
load_dotenv()

# Make Anthropic the default drivers config for every griptape structure built here.
Defaults.drivers_config = AnthropicDriversConfig()

# Griptape Cloud REST endpoint plus the headers sent with each direct API call.
# Indexing os.environ (instead of .get) makes a missing GT_CLOUD_API_KEY fail
# immediately at import time with a KeyError.
base_url = "https://cloud.griptape.ai"
headers_api = {
    "Authorization": "Bearer " + os.environ["GT_CLOUD_API_KEY"],
    "Content-Type": "application/json",
}

# Session key (the Gradio session hash) -> Griptape Cloud thread id.
threads = {}
# custom_css = """
# #component-2 {
# height: 75vh !important;
# min-height: 600px !important;
# """
def create_thread_id(session_id: str, timeout: float = 30.0) -> str:
    """Return the Griptape Cloud thread id for ``session_id``, creating it if needed.

    Ids are cached in the module-level ``threads`` dict, so only the first call
    for a given session issues a ``POST /api/threads`` request.

    Args:
        session_id: Key identifying the chat session (the Gradio session hash in
            this app). It is also sent as the new thread's ``name``.
        timeout: Seconds to wait for the Griptape Cloud API before giving up.

    Returns:
        The ``thread_id`` reported by Griptape Cloud for this session.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If the API does not answer within ``timeout`` seconds.
    """
    if session_id in threads:
        return threads[session_id]

    response = requests.post(
        url=urljoin(base_url, "/api/threads"),
        headers=headers_api,
        json={"name": session_id, "messages": []},
        # Without a timeout a stalled API call would block this chat request forever.
        timeout=timeout,
    )
    response.raise_for_status()
    thread_id = response.json()["thread_id"]
    threads[session_id] = thread_id
    return thread_id
# Create an agent that builds a prompt to use as input for the Griptape Cloud query agent.
# Function that logs user history - adds to history parameter of Gradio
# TODO: Figure out the exact use of this function
def user(user_message, history):
history.append([user_message, None])
return ("", history)
def bot(history, knowledge_bases=None, request=None, delay=0.005):
    """Stream the reply to the latest user turn into ``history``.

    Sends the newest user message (``history[-1][0]``) through ``send_message``.
    It then fills in ``history[-1][1]`` one character at a time, sleeping
    ``delay`` seconds between characters and yielding after each one.

    Args:
        history: List of ``[user_message, bot_reply]`` pairs. The last pair must
            hold the message to answer. The list is mutated in place.
        knowledge_bases: Forwarded to ``send_message`` (the selected knowledge bases).
        request: Forwarded to ``send_message``, which needs it to pick the
            session's conversation thread. Callers must supply it in practice;
            the ``None`` default only keeps the old one-argument signature valid.
        delay: Seconds to sleep between characters.

    Yields:
        The same ``history`` list after each character is appended.

    NOTE(review): nothing visible in this file calls this function.
    """
    # send_message takes (message, history, knowledge_bases, request). The old
    # one-argument call raised TypeError before any reply could be produced.
    # Only the turns before the one being answered are passed as history.
    reply = send_message(history[-1][0], history[:-1], knowledge_bases, request)
    history[-1][1] = ""
    for character in reply:
        history[-1][1] += character
        time.sleep(delay)
        yield history
def create_prompt_task(session_id: str, message: str) -> PromptTask:
    """Build the prompt task that asks the model to turn ``message`` into a query.

    The prompt embeds ``message`` verbatim (in single quotes). It tells the
    model to combine it with values from conversation memory and to drop
    unimportant attributes. ``session_id`` is accepted but currently unused.
    """
    prompt_text = (
        "\n"
        f"Re-structure the values to form a query from the user's questions: '{message}' "
        "and the input value from the conversation memory. Leave out attributes that "
        "aren't important to the user:\n"
    )
    return PromptTask(prompt_text)
def build_talk_agent(session_id: str, message: str) -> Agent:
    """Create the agent that turns the user's message into a structured query.

    The agent's conversation memory is the session's Griptape Cloud thread,
    created on demand via ``create_thread_id``. It runs a single prompt task
    built from ``message``. Its rules tell it to ask follow-up questions for
    any missing candidate attributes and to return the current query structure.
    """
    thread_id = create_thread_id(session_id)

    missing_info_rule = (
        "You ask the user follow-up questions to fill in missing information for:\n"
        "years experience,\n"
        "location,\n"
        "role,\n"
        "skills,\n"
        "expected salary,\n"
        "availability,\n"
        "past companies,\n"
        "past projects,\n"
        "show reel details\n"
    )
    talk_rules = Ruleset(
        name="Local Gradio Agent",
        rules=[
            Rule(value="You are responsible for structuring a user's questions into a specific format for a query."),
            Rule(value=missing_info_rule),
            Rule(value="Return the current query structure and any questions to fill in missing information."),
        ],
    )

    cloud_memory = ConversationMemory(
        conversation_memory_driver=GriptapeCloudConversationMemoryDriver(thread_id=thread_id),
    )
    return Agent(
        conversation_memory=cloud_memory,
        tasks=[create_prompt_task(session_id, message)],
        rulesets=[talk_rules],
    )
def build_agent(session_id: str, message: str, kbs: str) -> Agent:
    """Create the top-level agent that handles one chat message.

    A new agent object is built for every message. Its conversation memory is
    the Griptape Cloud thread mapped to ``session_id`` (created on demand via
    ``create_thread_id``), so one session's turns share a single thread.

    The agent receives two tools:
      * ``FormulateQueryFromUser``: runs a local talk agent (``build_talk_agent``)
        that structures the user's input into a query.
      * ``QueryResumeSearcher``: runs the Griptape Cloud structure identified by
        the ``GT_STRUCTURE_ID`` environment variable. ``kbs`` is embedded in its
        description, which tells the model to pass it as an extra argument.

    Args:
        session_id: Key of the conversation thread (the Gradio session hash).
        message: The user's latest message, forwarded to the talk agent.
        kbs: String form of the selected knowledge bases.
    """
    thread_id = create_thread_id(session_id)

    rule_texts = [
        "You are responsible for structuring a user's questions into a query and then querying.",
        "Only return the result of the query, do not provide additional commentary.",
        "Only perform one task at a time.",
        "Do not perform the query unless the user has said 'Done' with formulating.",
        "Only perform the query as one string argument.",
        "If you reformulate the query, then you must ask the user if they are 'Done' again.",
        # NOTE(review): memory lives in a Griptape Cloud thread and this agent has
        # no file tool, so there is no "conversation memory file" to delete.
        # Confirm the intended reset behaviour.
        "If the user says they want to start over, then you must delete the conversation memory file.",
    ]
    agent_rules = Ruleset(
        name="Local Gradio Agent",
        rules=[Rule(value=text) for text in rule_texts],
    )

    search_driver = GriptapeCloudStructureRunDriver(
        structure_id=os.getenv("GT_STRUCTURE_ID"),
        api_key=os.getenv("GT_CLOUD_API_KEY"),
        structure_run_wait_time_interval=3,
        structure_run_max_wait_time_attempts=30,
    )
    search_tool = StructureRunTool(
        name="QueryResumeSearcher",
        description=(
            "Use it to search for a candidate with the query.\n"
            f"Add this as another argument after the input: {kbs}\n"
        ),
        driver=search_driver,
    )

    formulate_tool = StructureRunTool(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        driver=LocalStructureRunDriver(
            # The lambda defers building the talk agent until the driver invokes the factory.
            structure_factory_fn=lambda: build_talk_agent(session_id, message),
        ),
    )

    return Agent(
        conversation_memory=ConversationMemory(
            conversation_memory_driver=GriptapeCloudConversationMemoryDriver(thread_id=thread_id),
        ),
        tools=[formulate_tool, search_tool],
        rulesets=[agent_rules],
    )
def send_message(message: str, history, knowledge_bases, request: gr.Request) -> Any:
    """Answer one chat message; used as the ``fn`` of this module's ``gr.ChatInterface``.

    Args:
        message: The user's new message.
        history: Prior chat turns supplied by Gradio. Unused here; conversation
            history is kept in the session's Griptape Cloud thread instead.
        knowledge_bases: Values selected in the "Select Knowledge Bases"
            checkbox group (wired in through ``additional_inputs``).
        request: The Gradio request; its ``session_hash`` selects the
            conversation thread.

    Returns:
        ``response.output.value`` from running a freshly built agent.

    Raises:
        ValueError: If no request is available to identify the session.
    """
    if not request:
        # Previously this fell through to an UnboundLocalError on session_hash.
        raise ValueError(
            "send_message needs a gr.Request to identify the chat session."
        )
    agent = build_agent(request.session_hash, message, str(knowledge_bases))
    response = agent.run(message)
    return response.output.value
with gr.Blocks() as demo:
    knowledge_bases = gr.CheckboxGroup(
        label="Select Knowledge Bases",
        choices=["skills", "demographics", "linked_in", "showreels"],
    )
    chatbot = gr.ChatInterface(
        fn=send_message,
        chatbot=gr.Chatbot(height=300),
        additional_inputs=knowledge_bases,
    )

# Require both login credentials. With either variable unset, Gradio would be
# handed None as a credential, which locks every user out instead of failing
# visibly at startup.
_gradio_username = os.environ.get("GRADIO_USERNAME")
_gradio_password = os.environ.get("GRADIO_PASSWORD")
if _gradio_username is None or _gradio_password is None:
    raise RuntimeError(
        "Set GRADIO_USERNAME and GRADIO_PASSWORD to launch the app with login."
    )
demo.launch(auth=(_gradio_username, _gradio_password))
# demo.launch(share=True)

# NOTE(review): this runs once, after demo.launch() returns, not at the end of
# each chat session. Per-session thread ids are not cleared while the app is
# serving; clearing them per session needs a different mechanism.
threads = {}