Kate Forsberg commited on
Commit
02585f2
·
1 Parent(s): a6c9658

updated to include KB selector

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -128,7 +128,7 @@ def build_talk_agent(session_id: str, message: str) -> Agent:
128
 
129
  # Creates an agent for each run
130
  # The agent uses local memory, which it differentiates between by session_hash.
131
- def build_agent(session_id: str, message: str) -> Agent:
132
 
133
  create_thread_id(session_id)
134
 
@@ -159,7 +159,9 @@ def build_agent(session_id: str, message: str) -> Agent:
159
 
160
  query_client = StructureRunTool(
161
  name="QueryResumeSearcher",
162
- description="Use it to search for a candidate with the query.",
 
 
163
  driver=GriptapeCloudStructureRunDriver(
164
  structure_id=os.getenv("GT_STRUCTURE_ID"),
165
  api_key=os.getenv("GT_CLOUD_API_KEY"),
@@ -186,15 +188,16 @@ def build_agent(session_id: str, message: str) -> Agent:
186
  )
187
 
188
 
189
- def send_message(message: str, history, request: gr.Request) -> Any:
190
  if request:
191
  session_hash = request.session_hash
192
- agent = build_agent(session_hash, message)
193
  response = agent.run(message)
194
  return response.output.value
195
 
196
-
197
- demo = gr.ChatInterface(fn=send_message)
 
198
  demo.launch(auth=(os.environ.get("GRADIO_USERNAME"), os.environ.get("GRADIO_PASSWORD")))
199
  # demo.launch(share=True)
200
 
 
128
 
129
  # Creates an agent for each run
130
  # The agent uses local memory, which it differentiates between by session_hash.
131
+ def build_agent(session_id: str, message: str, kbs:str) -> Agent:
132
 
133
  create_thread_id(session_id)
134
 
 
159
 
160
  query_client = StructureRunTool(
161
  name="QueryResumeSearcher",
162
+ description=f"""Use it to search for a candidate with the query.
163
+ Add this as another argument after the input: {kbs}
164
+ """,
165
  driver=GriptapeCloudStructureRunDriver(
166
  structure_id=os.getenv("GT_STRUCTURE_ID"),
167
  api_key=os.getenv("GT_CLOUD_API_KEY"),
 
188
  )
189
 
190
 
191
+ def send_message(message: str, history, knowledge_bases, request: gr.Request) -> Any:
192
  if request:
193
  session_hash = request.session_hash
194
+ agent = build_agent(session_hash, message, str(knowledge_bases))
195
  response = agent.run(message)
196
  return response.output.value
197
 
198
+ with gr.Blocks() as demo:
199
+ knowledge_bases = gr.CheckboxGroup(choices=["skills","demographics","linked_in","showreels"])
200
+ chatbot = gr.ChatInterface(fn=send_message, additional_inputs=knowledge_bases)
201
  demo.launch(auth=(os.environ.get("GRADIO_USERNAME"), os.environ.get("GRADIO_PASSWORD")))
202
  # demo.launch(share=True)
203