Spaces:
Sleeping
Sleeping
import glob | |
import gradio as gr | |
from uuid import uuid4 as uuid | |
from huggingface_hub import HfApi | |
from typing import Any | |
from dotenv import load_dotenv | |
from griptape.structures import Agent | |
from griptape.tasks import PromptTask, StructureRunTask, ToolkitTask | |
from griptape.drivers import LocalConversationMemoryDriver, GriptapeCloudStructureRunDriver, GriptapeCloudEventListenerDriver, LocalFileManagerDriver, LocalStructureRunDriver | |
from griptape.memory.structure import ConversationMemory | |
from griptape.tools import StructureRunClient, TaskMemoryClient, FileManager | |
from griptape.rules import Rule, Ruleset | |
from griptape.config import AnthropicStructureConfig | |
from griptape.events import EventListener, FinishStructureRunEvent | |
import time | |
import os | |
# Populate os.environ from a local .env file (if one exists) before any
# os.getenv lookups below run; variables already set in the process
# environment are kept as they are.
load_dotenv(override=False)
# Chat handlers and agent factories: one agent turns the user's messages into a structured candidate-search query, which is then run against the query structure hosted on Griptape Cloud.
def user(user_message, history):
    """Record a new user turn in a Gradio-style chat history.

    Appends ``[user_message, None]`` to *history* in place; the ``None`` slot
    is meant to be filled later with the bot's reply. Returns an empty string
    (presumably to clear the input textbox) together with the same list object.

    NOTE(review): not referenced by the ``gr.ChatInterface`` at the end of this
    file; it looks intended for a ``gr.Blocks`` textbox ``submit`` event.
    """
    pending_turn = [user_message, None]
    history.append(pending_turn)
    cleared_textbox = ""
    return cleared_textbox, history
def bot(history, request: "gr.Request" = None):
    """Stream the agent's reply to the latest user turn into *history*.

    Generator intended for Gradio streaming. It sends the last user message to
    ``send_message``, then reveals the reply one character at a time, yielding
    the updated history after each character.

    Args:
        history: list of ``[user_message, bot_message]`` pairs. The bot slot of
            the last pair is overwritten in place.
        request: optional Gradio request, forwarded to ``send_message`` so the
            reply uses the caller's per-session memory.
            NOTE(review): confirm that Gradio injects it for this signature in
            the Gradio version in use.

    Yields:
        The same *history* list after each appended character.
    """
    latest_message = history[-1][0]
    # send_message takes (message, history, request); pass the turns that
    # precede the current one as its history argument.
    response = send_message(latest_message, history[:-1], request)
    history[-1][1] = ""
    for character in response:
        history[-1][1] += character
        time.sleep(0.005)  # small delay to produce a typing effect in the UI
        yield history
def create_prompt_task(session_id: str, message: str, tools=None) -> PromptTask:
    """Build the task that restructures the user's request into the query format.

    Args:
        session_id: chat-session id. The prompt refers the model to the
            conversation-memory file ``<session_id>.json``.
        message: the user's message to restructure.
        tools: optional list of griptape tools. When given, a ``ToolkitTask`` is
            returned so the model can call those tools (e.g. to read the memory
            file). Otherwise a plain ``PromptTask`` is returned, as before.
            NOTE(review): ToolkitTask subclasses PromptTask in griptape, so the
            return annotation still holds; confirm for the pinned version.

    Returns:
        The task, with ``message`` and ``session_id`` supplied via its context.
    """
    # griptape renders a string task input as a Jinja template. The user's text
    # is therefore passed as a context variable instead of being interpolated
    # into the template; interpolating it would let braces such as "{{" or "{%"
    # in a chat message be parsed as template syntax.
    # NOTE(review): verify `context` handling against the pinned griptape version.
    template = """
Re-structure the values from the user's questions: '{{ message }}' and the input value from the conversation memory '{{ session_id }}.json' to fit the following format. Leave out attributes that aren't important to the user:
years experience: <x>
location: <x>
role: <x>
skills: <x>
expected salary: <x>
availability: <x>
past companies: <x>
past projects: <x>
show reel details: <x>
"""
    context = {"message": message, "session_id": session_id}
    if tools:
        return ToolkitTask(template, tools=list(tools), context=context)
    return PromptTask(template, context=context)


def build_talk_agent(session_id: str, message: str) -> Agent:
    """Build the local agent that turns the user's input into the query format.

    Its conversation memory lives in ``<session_id>.json``. It is given a
    ``FileManager`` tool backed by ``LocalFileManagerDriver`` and a single task
    from ``create_prompt_task``.

    Args:
        session_id: chat-session id; names the memory file.
        message: the user's message to restructure.

    Returns:
        A configured ``Agent`` using ``AnthropicStructureConfig``.
    """
    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[
            Rule(
                value="You are responsible for structuring a user's questions into a specific format for a query."
            ),
            Rule(
                value="You ask the user follow-up questions to fill in missing information for the format you are trying to fit."
            ),
            Rule(
                value="If the user has no preference for a specific attribute, then you can remove it from the query."
            ),
            Rule(
                value="Only return the current query structure and any questions to fill in missing information."
            ),
        ],
    )
    file_manager_tool = FileManager(
        name="FileManager",
        file_manager_driver=LocalFileManagerDriver(),
        off_prompt=False,
    )
    return Agent(
        config=AnthropicStructureConfig(),
        conversation_memory=ConversationMemory(
            driver=LocalConversationMemoryDriver(file_path=f"{session_id}.json")
        ),
        tools=[file_manager_tool],
        # In griptape, Agent(tools=...) is applied only to a default task the
        # Agent builds itself. Because the task is passed explicitly here, the
        # tool has to be attached to the task, or the model never sees it.
        # NOTE(review): confirm this Agent behaviour for the pinned griptape version.
        tasks=[create_prompt_task(session_id, message, tools=[file_manager_tool])],
        rulesets=[ruleset],
    )
def build_agent(session_id: str, message: str) -> Agent:
    """Create the user-facing agent for a single chat turn.

    A new agent is built on every call. Its conversation memory is stored in the
    local file ``<session_id>.json``, so different sessions keep separate
    histories. The agent has two tools:

    * ``FormulateQueryFromUser``: runs the local agent built by
      ``build_talk_agent`` to restructure the user's input.
    * ``QueryResumeSearcher``: runs the Griptape Cloud structure whose id is in
      ``GT_STRUCTURE_ID``, authenticated with ``GT_CLOUD_API_KEY``.

    Args:
        session_id: identifier used to name the conversation-memory file.
        message: the user's current message; forwarded to ``build_talk_agent``.

    Returns:
        A configured ``Agent`` using ``AnthropicStructureConfig``.

    Raises:
        ValueError: if ``GT_STRUCTURE_ID`` or ``GT_CLOUD_API_KEY`` is unset or empty.
    """
    structure_id = os.getenv("GT_STRUCTURE_ID")
    api_key = os.getenv("GT_CLOUD_API_KEY")
    # Report missing configuration clearly up front. Otherwise the cloud driver
    # would be built with None values and presumably fail only when the tool
    # is called mid-conversation.
    missing = [
        name
        for name, value in (("GT_STRUCTURE_ID", structure_id), ("GT_CLOUD_API_KEY", api_key))
        if not value
    ]
    if missing:
        raise ValueError(
            "Missing required environment variable(s): " + ", ".join(missing)
        )

    ruleset = Ruleset(
        name="Local Gradio Agent",
        rules=[
            Rule(
                value="You are responsible for structuring a user's questions into a specific format for a query and then querying."
            ),
            Rule(
                value="Only return the result of the query, do not provide additional commentary."
            ),
            Rule(
                value="Only perform one task at a time."
            ),
            Rule(
                value="Do not perform the query unless the user has said 'Done' with formulating."
            ),
            Rule(
                value="Only perform the query with the proper query structure as one string argument."
            ),
            Rule(
                value="If you reformulate the query, then you must ask the user if they are 'Done' again."
            ),
        ],
    )
    # No base_url is passed, so the driver's own default endpoint is used.
    query_client = StructureRunClient(
        name="QueryResumeSearcher",
        description="Use it to search for a candidate with the query.",
        driver=GriptapeCloudStructureRunDriver(
            structure_id=structure_id,
            api_key=api_key,
            # Wait interval of 5 and at most 30 wait attempts for the cloud run
            # (presumably seconds, i.e. roughly 150 s in total).
            structure_run_wait_time_interval=5,
            structure_run_max_wait_time_attempts=30,
        ),
    )
    talk_client = StructureRunClient(
        name="FormulateQueryFromUser",
        description="Used to formulate a query from the user's input.",
        driver=LocalStructureRunDriver(
            # Build the helper agent lazily, only when the tool is actually run.
            structure_factory_fn=lambda: build_talk_agent(session_id, message),
        ),
    )
    return Agent(
        config=AnthropicStructureConfig(),
        conversation_memory=ConversationMemory(
            driver=LocalConversationMemoryDriver(file_path=f"{session_id}.json")
        ),
        tools=[talk_client, query_client],
        rulesets=[ruleset],
    )
def send_message(message: str, history, request: gr.Request) -> Any:
    """Run one chat turn through a freshly built agent and return its reply text.

    Used as the ``fn`` of the ``gr.ChatInterface`` at the end of this file.
    Conversation context is not taken from *history*. It is kept per session in
    a local JSON memory file (see ``build_agent``), keyed by an id derived from
    the request's session hash.

    Args:
        message: the user's latest chat message.
        history: chat history supplied by Gradio; not used here.
        request: the Gradio request for this call. May be ``None`` (e.g. when
            called directly), in which case a one-off random session id is used.

    Returns:
        The ``value`` of the agent's output artifact (the reply text).
    """
    import hashlib

    session_hash = getattr(request, "session_hash", None) if request else None
    if session_hash:
        # The id becomes part of a file name ("<id>.json"), and the session hash
        # arrives with the request (NOTE(review): in Gradio it appears to be
        # client-supplied). Hash it to a fixed-format hex string so it can never
        # act as a path component such as "../".
        session_id = hashlib.sha256(str(session_hash).encode("utf-8")).hexdigest()
    else:
        # No session hash available: use a throwaway id so the call still works.
        # Memory will not carry over between such calls.
        session_id = uuid().hex
    agent = build_agent(session_id, message)
    response = agent.run(message)
    return response.output.value
# Chat UI: every submitted message is handled by send_message.
demo = gr.ChatInterface(
    fn=send_message,
)

# Start the server only when run as a script. Importing this module (e.g. by
# tooling such as Gradio's reload mode, which looks up `demo`) then has no
# side effect beyond building the interface.
if __name__ == "__main__":
    demo.launch()