kateforsberg's picture
re-add username and password
e43e145
raw
history blame
8.85 kB
import glob
from venv import create
import gradio as gr
from typing import Any
from dotenv import load_dotenv
import requests
from griptape.structures import Agent, Structure, Workflow
from griptape.tasks import PromptTask, StructureRunTask
from griptape.drivers import (
LocalConversationMemoryDriver,
GriptapeCloudStructureRunDriver,
LocalFileManagerDriver,
LocalStructureRunDriver,
GriptapeCloudConversationMemoryDriver,
)
from griptape.memory.structure import ConversationMemory
from griptape.tools import StructureRunTool, FileManagerTool
from griptape.rules import Rule, Ruleset
from griptape.configs.drivers import AnthropicDriversConfig
from griptape.configs import Defaults
import time
import os
from urllib.parse import urljoin
# Pull environment variables (e.g. GT_CLOUD_API_KEY) in from a local .env file.
load_dotenv()
# Every Griptape structure built below uses Anthropic-backed drivers by default.
Defaults.drivers_config = AnthropicDriversConfig()

# Root of the Griptape Cloud REST API; endpoint paths are joined onto it with urljoin.
base_url = "https://cloud.griptape.ai"

# Headers for direct Griptape Cloud API calls. Indexing os.environ (instead of
# .get) makes a missing GT_CLOUD_API_KEY fail immediately with a KeyError.
headers_api = {
    "Authorization": "Bearer " + os.environ["GT_CLOUD_API_KEY"],
    "Content-Type": "application/json",
}

# Cache: Gradio session id -> Griptape Cloud thread id (filled by create_thread_id).
threads = {}
# custom_css = """
# #component-2 {
# height: 75vh !important;
# min-height: 600px !important;
# """
def create_thread_id(session_id: str) -> str:
    """Return the Griptape Cloud thread id for a session, creating the thread once.

    The id is cached in the module-level ``threads`` dict, so repeated calls
    for the same session make no further API requests.

    Args:
        session_id: Gradio session hash; also used as the new thread's name.

    Returns:
        The cached thread id, or the ``thread_id`` returned by
        ``POST /api/threads`` for a newly created thread.

    Raises:
        requests.HTTPError: if the thread-creation request fails.
        requests.Timeout: if Griptape Cloud does not answer in time.
    """
    if session_id in threads:
        return threads[session_id]

    params = {
        "name": session_id,
        "messages": [],
    }
    response = requests.post(
        url=urljoin(base_url, "/api/threads"),
        headers=headers_api,
        json=params,
        # requests never times out by default; don't let a stalled call hang the UI.
        timeout=30,
    )
    response.raise_for_status()
    thread_id = response.json()["thread_id"]
    threads[session_id] = thread_id
    return thread_id
# Create an agent that will create a prompt that can be used as input for the query agent from the Griptape Cloud.
def user(user_message, history):
    """Record a new user turn in the Gradio chat history.

    Appends ``[user_message, None]`` to *history* in place. The ``None`` slot
    holds the bot reply that has not been generated yet. Returns an empty
    string (presumably to clear the input textbox) together with the same,
    now-updated history list.

    TODO: not referenced elsewhere in this file; presumably meant to be wired
    as a Gradio event handler ahead of ``bot``.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    cleared_input = ""
    return cleared_input, history
def bot(history, knowledge_bases=None, request=None):
    """Stream the assistant's reply to the latest user turn into *history*.

    Sends the last user message to ``send_message``, then fills the reply slot
    of the last turn one character at a time. After each character it yields
    the (same, mutated) history list, so a Gradio chatbot can render the reply
    progressively.

    Args:
        history: Chat history as a list of ``[user_message, bot_reply]`` pairs;
            the last pair's reply slot is overwritten.
        knowledge_bases: Selected knowledge bases to forward; ``None`` means none.
        request: The Gradio request for the session, forwarded unchanged.
            ``send_message`` needs it to identify the session.

    Yields:
        *history* after each streamed character.

    TODO: not referenced elsewhere in this file; the ChatInterface below calls
    ``send_message_call`` directly.
    """
    if knowledge_bases is None:
        knowledge_bases = []
    latest_turn = history[-1]
    # send_message takes (message, history, knowledge_bases, request); the old
    # one-argument call always failed with TypeError.
    response = send_message(latest_turn[0], history, knowledge_bases, request)
    latest_turn[1] = ""
    for character in response:
        latest_turn[1] += character
        # Small delay gives the UI a typewriter-style streaming effect.
        time.sleep(0.005)
        yield history
def create_prompt_task(session_id: str, message: str) -> PromptTask:
    """Create the prompt task that asks the model to restructure *message* into a query.

    The prompt tells the model to combine the user's question with values from
    conversation memory and to drop attributes that don't matter to the user.

    Args:
        session_id: Currently unused; kept for call-site compatibility.
        message: The user's latest question, embedded verbatim in the prompt.

    Returns:
        A ``PromptTask`` wrapping the assembled prompt text.
    """
    prompt = (
        "\nRe-structure the values to form a query from the user's questions: "
        f"'{message}' and the input value from the conversation memory. "
        "Leave out attributes that aren't important to the user:\n"
    )
    return PromptTask(prompt)
def build_talk_agent(session_id: str, message: str) -> Agent:
    """Build the sub-agent that turns the user's conversation into a structured query.

    The agent shares the session's Griptape Cloud conversation thread, so it
    sees earlier turns. Its rules make it ask follow-up questions for missing
    candidate attributes and return the current query structure.

    Args:
        session_id: Gradio session hash; used to look up or create the cloud thread.
        message: The user's latest message, embedded in the agent's prompt task.

    Returns:
        An ``Agent`` with a single prompt task and the query-formatting ruleset.
    """
    thread_id = create_thread_id(session_id)

    missing_fields_rule = (
        "You ask the user follow-up questions to fill in missing information for:\n"
        "years experience,\n"
        "location,\n"
        "role,\n"
        "skills,\n"
        "expected salary,\n"
        "availability,\n"
        "past companies,\n"
        "past projects,\n"
        "show reel details\n"
    )
    formatting_rules = [
        Rule(
            value="You are responsible for structuring a user's questions into a specific format for a query."
        ),
        Rule(value=missing_fields_rule),
        Rule(
            value="Return the current query structure and any questions to fill in missing information."
        ),
    ]
    ruleset = Ruleset(name="Local Gradio Agent", rules=formatting_rules)

    cloud_memory = ConversationMemory(
        conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
            thread_id=thread_id,
        )
    )
    return Agent(
        conversation_memory=cloud_memory,
        tasks=[create_prompt_task(session_id, message)],
        rulesets=[ruleset],
    )
def build_agent(session_id: str, message: str, kbs: str) -> Agent:
    """Build the top-level agent that formulates a candidate query and runs it.

    A fresh agent is built for every message. Conversation history carries over
    between messages through a Griptape Cloud conversation memory bound to the
    session's thread (see ``create_thread_id``), not through local memory.

    Args:
        session_id: Gradio session hash identifying the conversation.
        message: The user's latest message. It is forwarded to the
            query-formulation sub-agent built by ``build_talk_agent``.
        kbs: String form of the selected knowledge bases. It is interpolated
            into the search tool's description, which tells the model to pass
            each item as a separate argument.

    Returns:
        An ``Agent`` equipped with a local "FormulateQueryFromUser" tool and a
        Griptape Cloud-backed "QueryResumeSearcher" tool.
    """
    thread_id = create_thread_id(session_id)
    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[
            Rule(
                value="You are responsible for structuring a user's questions into a query and then querying."
            ),
            Rule(
                value="Only return the result of the query, do not provide additional commentary."
            ),
            Rule(value="Only perform one task at a time."),
            Rule(
                value="Do not perform the query unless the user has confirmed they are done with formulating."
            ),
            Rule(value="Only perform the query as one string argument."),
            Rule(
                value="If the user says they want to start over, then you must delete the conversation memory file."
            ),
            Rule(
                value="Do not ever search conversation memory for a formulated query instead of querying. Query every time."
            ),
        ],
    )
    query_client = StructureRunTool(
        name="QueryResumeSearcher",
        description=f"""Use it to search for a candidate with the query. Add each item in this list as separate arguments:{kbs}. Do not add any other arguments.""",
        # StructureRunTool takes its driver via `structure_run_driver` (the same
        # keyword talk_client uses below); the previous `driver=` is not a field.
        structure_run_driver=GriptapeCloudStructureRunDriver(
            structure_id=os.getenv("GT_STRUCTURE_ID"),
            api_key=os.getenv("GT_CLOUD_API_KEY"),
            structure_run_wait_time_interval=3,
            structure_run_max_wait_time_attempts=30,
        ),
    )
    talk_client = StructureRunTool(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        structure_run_driver=LocalStructureRunDriver(
            # Built lazily, only when the tool is actually invoked.
            create_structure=lambda: build_talk_agent(session_id, message),
        ),
    )
    return Agent(
        conversation_memory=ConversationMemory(
            conversation_memory_driver=GriptapeCloudConversationMemoryDriver(
                thread_id=thread_id,
            )
        ),
        tools=[talk_client, query_client],
        rulesets=[ruleset],
    )
def send_message(message: str, history, knowledge_bases, request: gr.Request) -> Any:
    """Answer *message* with a freshly built agent scoped to the caller's Gradio session.

    Args:
        message: The user's latest chat message.
        history: Gradio chat history; unused, since conversation state lives in
            the session's Griptape Cloud thread.
        knowledge_bases: Selected knowledge bases; passed to ``build_agent`` as a string.
        request: The Gradio request. Its ``session_hash`` identifies the session.

    Returns:
        The text value of the agent run's output.

    Raises:
        ValueError: if no request is available, so there is no session to scope
            the conversation to. Previously this surfaced as an UnboundLocalError.
    """
    if not request:
        raise ValueError(
            "send_message requires a Gradio request to identify the chat session"
        )
    agent = build_agent(request.session_hash, message, str(knowledge_bases))
    return agent.run(message).output.value
def send_message_call(message: str, history, knowledge_bases) -> Any:
    """Run the Griptape Cloud structure directly and return its final text output.

    Starts a structure run with ``[message, *knowledge_bases]`` as arguments,
    then polls the run's events until it finishes.

    Args:
        message: The user's chat message (first run argument).
        history: Gradio chat history; unused.
        knowledge_bases: Selected knowledge-base names, appended as extra run
            arguments. ``None`` is treated as no selection.

    Returns:
        The value of the run's ``output_task_output``, or
        ``"Assistant Call Failed"`` if the create-run call returned a non-201
        success status.

    Raises:
        requests.HTTPError: if starting the run fails.
        requests.Timeout: if Griptape Cloud does not answer in time.
    """
    structure_id = os.getenv("GT_STRUCTURE_ID")
    api_key = os.getenv("GT_CLOUD_API_KEY")
    structure_url = f"https://cloud.griptape.ai/api/structures/{structure_id}/runs"
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    # Presumably Gradio can pass None for an untouched CheckboxGroup; `*None`
    # would raise TypeError, so treat it as an empty selection.
    payload = {"args": [message, *(knowledge_bases or [])]}
    # requests never times out by default; bound the wait for the create-run call.
    response = requests.post(structure_url, headers=headers, json=payload, timeout=30)
    response.raise_for_status()
    if response.status_code != 201:
        return "Assistant Call Failed"
    structure_run_id = response.json()["structure_run_id"]
    output = poll_structure(structure_run_id, headers)
    return output["output_task_output"]["value"]
def poll_for_events(offset: int, structure_run_id: str, headers: dict):
    """Fetch one page of events (at most 100) for a Griptape Cloud structure run.

    Args:
        offset: Event offset to start from (``next_offset`` of the previous page).
        structure_run_id: Id of the run whose events are fetched.
        headers: HTTP headers, including the Authorization bearer token.

    Returns:
        The ``requests.Response``, already checked with ``raise_for_status``.

    Raises:
        requests.HTTPError: if the API returns an error status.
        requests.Timeout: if the API does not answer in time.
    """
    url = f"https://cloud.griptape.ai/api/structure-runs/{structure_run_id}/events"
    response = requests.get(
        url=url,
        headers=headers,
        params={"offset": offset, "limit": 100},
        # requests never times out by default; a stalled connection would
        # otherwise hang the polling loop forever.
        timeout=30,
    )
    response.raise_for_status()
    return response
def poll_structure(structure_run_id: str, headers: dict, *, poll_interval: float = 0.5, max_attempts=None):
    """Poll a structure run's events until it finishes and return the finish payload.

    Fetches pages of events via ``poll_for_events``, advancing the offset each
    time. It returns as soon as a ``FinishStructureRunEvent`` is seen, with no
    extra request after the run has finished.

    Args:
        structure_run_id: Id of the Griptape Cloud structure run to watch.
        headers: HTTP headers, including the Authorization bearer token.
        poll_interval: Seconds to sleep between polls that did not find the
            finish event.
        max_attempts: Optional cap on the number of event pages fetched;
            ``None`` (the default) polls until the run finishes.

    Returns:
        A dict copy of the finish event's ``payload``.

    Raises:
        TimeoutError: if *max_attempts* pages were fetched without a finish event.
        requests.HTTPError: propagated from ``poll_for_events``.
    """
    offset = 0
    attempts = 0
    while True:
        # poll_for_events already calls raise_for_status on the response.
        body = poll_for_events(offset, structure_run_id, headers).json()
        attempts += 1
        for event in body["events"]:
            if event["type"] == "FinishStructureRunEvent":
                return dict(event["payload"])
        offset = body["next_offset"]
        if max_attempts is not None and attempts >= max_attempts:
            raise TimeoutError(
                f"structure run {structure_run_id} did not finish after {attempts} polls"
            )
        time.sleep(poll_interval)
# Chat UI: the selected knowledge bases are passed to send_message_call as its
# third argument through ChatInterface.additional_inputs.
with gr.Blocks() as demo:
    knowledge_bases = gr.CheckboxGroup(
        label="Select Knowledge Bases",
        choices=["skills", "demographics", "linked_in", "showreels"],
    )
    chatbot = gr.ChatInterface(
        fn=send_message_call,
        chatbot=gr.Chatbot(height=300),
        additional_inputs=knowledge_bases,
    )

# Basic-auth credentials. Without them the old code launched with
# auth=(None, None), which presumably leaves no login that can ever succeed;
# fail fast with a clear message instead.
gradio_username = os.environ.get("GRADIO_USERNAME")
gradio_password = os.environ.get("GRADIO_PASSWORD")
if gradio_username is None or gradio_password is None:
    raise RuntimeError(
        "Set GRADIO_USERNAME and GRADIO_PASSWORD to launch the app with authentication."
    )
demo.launch(auth=(gradio_username, gradio_password))

# NOTE(review): when run as a script, launch() blocks until the server stops,
# so this reset only runs at shutdown. It does not clear per-session threads.
threads = {}