Yannael_LB
Handling of session states
1c34aa2
raw
history blame
6.47 kB
# -*- coding: utf-8 -*-
"""OpenAI_Assistant_WanderLust.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1DLYjk07uQwF_Oqav1NtqMMCz_S2f0q8B
# OpenAI assistant Wanderlust
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API.
This space is inspired by the implementation of Fanilo Andrianasolo using Streamlit - https://www.youtube.com/watch?v=tLeqCDKgEDU
"""
#!pip install -q -U gradio openai datasets
import gradio as gr
import random
import openai
import os
import json
import plotly.graph_objects as go
import time
# Only fall back to the inline placeholder when no key is configured yet, so a
# key supplied through the environment (shell export, hosted secret, ...) is
# not clobbered by the placeholder below.
os.environ.setdefault("OPENAI_API_KEY", "sk-...")  # Replace with your key

# Identifier of the pre-configured assistant this app talks to.
assistant_id = "asst_7OC3NTeyCjEZrApdLRklplE7"

# Shared API client and the assistant definition fetched once at start-up.
client = openai.OpenAI()
assistant = client.beta.assistants.retrieve(assistant_id)
#######################################
# TOOLS SETUP
#######################################
def update_map_state(map_metadata_state, latitude, longitude, zoom):
    """Tool callback: move the in-app map to a new view.

    Stores the requested center (``latitude``/``longitude``) and ``zoom`` in
    ``map_metadata_state`` (mutated in place), echoes the resulting state to
    stdout, and returns the confirmation string handed back to the assistant.
    """
    map_metadata_state.update(
        latitude=latitude,
        longitude=longitude,
        zoom=zoom,
    )
    print(map_metadata_state)
    return "Map updated"
def add_markers_state(map_metadata_state, latitudes, longitudes, labels):
    """Tool callback: replace the markers shown on the in-app map.

    Stores the three sequences under the ``"lat"``, ``"lon"`` and ``"text"``
    keys of ``map_metadata_state`` (mutated in place), echoes the resulting
    state to stdout, and returns the confirmation string handed back to the
    assistant.
    """
    marker_fields = (
        ("lat", latitudes),
        ("lon", longitudes),
        ("text", labels),
    )
    for key, values in marker_fields:
        map_metadata_state[key] = values
    print(map_metadata_state)
    return "Markers added"
# Dispatch table keyed by the function name carried in an assistant tool call;
# each handler receives the map state followed by the call's keyword arguments.
tool_to_function = dict(
    update_map=update_map_state,
    add_markers=add_markers_state,
)
## Helpers
def submit_message(assistant_id, thread_id, user_message):
    """Append ``user_message`` to the thread and start a run on it.

    The message is posted with the ``"user"`` role, then a run is created for
    ``assistant_id`` on the same thread. Returns the run object created by
    the API.
    """
    threads_api = client.beta.threads
    threads_api.messages.create(
        content=user_message,
        role="user",
        thread_id=thread_id,
    )
    return threads_api.runs.create(
        assistant_id=assistant_id,
        thread_id=thread_id,
    )
def get_run_info(run_id, thread_id):
    """Fetch and return the current state of run ``run_id`` on ``thread_id``."""
    return client.beta.threads.runs.retrieve(run_id=run_id, thread_id=thread_id)
#######################################
# INITIAL DATA FOR MAP
#######################################
# Initial map view plus (empty) marker lists. "latitude"/"longitude"/"zoom"
# describe the view; "lat"/"lon"/"text" hold the marker data.
map_metadata = dict(
    latitude=48.85,
    longitude=2.35,
    zoom=12,
    lat=[],
    lon=[],
    text=[],
)

# Marker-free figure displayed until the first response redraws the map.
fig = go.Figure(go.Scattermapbox())
fig.update_layout(
    mapbox_style="open-street-map",
    hovermode="closest",
    mapbox={
        "center": go.layout.mapbox.Center(
            lat=map_metadata["latitude"],
            lon=map_metadata["longitude"],
        ),
        "zoom": map_metadata["zoom"],
    },
)
# Run states from which the run will never reach "completed"; polling must
# stop on these instead of looping forever.
_FAILED_RUN_STATUSES = frozenset({"failed", "cancelled", "expired", "incomplete"})


def _run_tool_call(tool_call, map_metadata_state):
    """Execute one assistant tool call and return its tool-output payload."""
    function = tool_call.function
    handler = tool_to_function.get(function.name)
    if handler is None:
        # Report the problem to the assistant rather than crashing and
        # leaving the run waiting for outputs.
        result = f"Unknown tool: {function.name}"
    else:
        try:
            result = handler(map_metadata_state, **json.loads(function.arguments))
        except (TypeError, ValueError) as err:
            # ValueError covers malformed JSON; TypeError covers arguments
            # that do not match the handler's signature.
            result = f"Invalid arguments for {function.name}: {err}"
    return {"tool_call_id": tool_call.id, "output": result}


def _message_text(message):
    """Join the text parts of a thread message, skipping non-text parts."""
    return "\n".join(
        part.text.value
        for part in message.content
        if getattr(part, "type", None) == "text"
    )


def _dialog_pairs(messages):
    """Group chronological thread messages into ``[user, assistant]`` rows."""
    rows = []
    for message in messages:
        text = _message_text(message)
        if message.role == "user":
            rows.append([text, None])
        elif not rows:
            # Assistant message with no preceding user message.
            rows.append([None, text])
        elif rows[-1][1] is None:
            rows[-1][1] = text
        else:
            # Several assistant messages for one user turn share a row.
            rows[-1][1] = f"{rows[-1][1]}\n\n{text}"
    return rows


def _build_map_figure(map_metadata_state):
    """Build the map figure (with markers, if any) described by the state."""
    if map_metadata_state["lat"]:
        trace = go.Scattermapbox(
            customdata=map_metadata_state["text"],
            lat=map_metadata_state["lat"],
            lon=map_metadata_state["lon"],
            mode='markers',
            marker=go.scattermapbox.Marker(
                size=18
            ),
            hoverinfo="text",
            hovertemplate='<b>Name</b>: %{customdata}'
        )
    else:
        trace = go.Scattermapbox()
    figure = go.Figure(trace)
    figure.update_layout(
        mapbox_style="open-street-map",
        hovermode='closest',
        mapbox=dict(
            center=go.layout.mapbox.Center(
                lat=map_metadata_state["latitude"],
                lon=map_metadata_state["longitude"]
            ),
            # Honour the zoom stored by the update_map tool instead of a
            # fixed value.
            zoom=map_metadata_state["zoom"]
        ),
    )
    return figure


def respond(message, chat_history, thread, map_metadata_state):
    """Send ``message`` to the assistant and refresh the chat and the map.

    Args:
        message: Text typed by the user.
        chat_history: Current chatbot rows; replaced by the history rebuilt
            from the thread's messages.
        thread: Conversation thread for this session, or ``None`` to create
            a new one.
        map_metadata_state: Map view and marker data; mutated by the tool
            handlers the assistant invokes.

    Returns:
        A 5-tuple ``("", chat_history, figure, thread, map_metadata_state)``
        matching the outputs wired to this handler (the empty string clears
        the input box).

    Raises:
        gr.Error: If the run ends in a state other than ``"completed"``.
    """
    if thread is None:
        thread = client.beta.threads.create()
    print(map_metadata_state)
    run = submit_message(assistant.id, thread.id, message)

    # Poll until the run completes, serving tool calls along the way.
    while True:
        run = get_run_info(run.id, thread.id)
        if run.status == "completed":
            break
        if run.status in _FAILED_RUN_STATUSES:
            last_error = getattr(run, "last_error", None)
            detail = f": {last_error.message}" if last_error else ""
            raise gr.Error(f"Assistant run {run.status}{detail}")
        if run.status == "requires_action":
            pending_calls = run.required_action.submit_tool_outputs.tool_calls
            client.beta.threads.runs.submit_tool_outputs(
                thread_id=thread.id,
                run_id=run.id,
                tool_outputs=[
                    _run_tool_call(tool_call, map_metadata_state)
                    for tool_call in pending_calls
                ],
            )
        time.sleep(0.1)

    # Iterate the returned page object itself (not ``.data``) so every page
    # of the thread is read, not just the first one.
    thread_messages = client.beta.threads.messages.list(
        thread.id, order="asc", limit=100
    )
    chat_history = _dialog_pairs(thread_messages)
    print(chat_history)

    figure = _build_map_figure(map_metadata_state)
    return "", chat_history, figure, thread, map_metadata_state
# UI layout and event wiring for the app.
with gr.Blocks(title="OpenAI assistant Wanderlust") as demo:
    gr.Markdown("# OpenAI assistant Wanderlust")

    # State carried between calls to respond(): the conversation thread
    # (initially None) and the map view/marker data.
    thread = gr.State()
    map_metadata_state = gr.State(map_metadata)

    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot()
            # Named map_plot so the builtin ``map`` is not shadowed.
            map_plot = gr.Plot(fig)
        msg = gr.Textbox("Move the map to Brussels and add markers on shops with the best waffles")
    with gr.Column():
        with gr.Row():
            submit = gr.Button("Submit")
            clear = gr.ClearButton([msg, chatbot])

    # Pressing Enter in the textbox and clicking Submit trigger the same handler.
    respond_inputs = [msg, chatbot, thread, map_metadata_state]
    respond_outputs = [msg, chatbot, map_plot, thread, map_metadata_state]
    msg.submit(respond, respond_inputs, respond_outputs)
    submit.click(respond, respond_inputs, respond_outputs)

    # respond() rebuilds the chat from the thread's messages, so clearing the
    # chatbot alone would bring the old history back on the next message.
    # Dropping the thread makes the next message start a new conversation.
    clear.click(lambda: None, None, [thread])

    gr.Markdown(
        """
# Description
Basic recreation of OpenAI's DevDay Wanderlust demo app. It relies on Gradio and the new Assistants API. [Github repository](https://github.com/Yannael/openai-assistant-wanderlust)
This space is inspired by the implementation of [Fanilo Andrianasolo using Streamlit](https://www.youtube.com/watch?v=tLeqCDKgEDU)
"""
    )

demo.queue()
demo.launch(debug=True)